Commit 70b7a68f authored by Alan Turner's avatar Alan Turner
Browse files

Further improvements

parent 6cb25a15
#%matplotlib
import subprocess, csv, re, datetime, argparse, os, enum
from subprocess import STDOUT
def parse_args():
......@@ -93,11 +91,11 @@ def get_total_time(output):
def verify_output(output, provider):
summary = re.findall("Summary.*", output)[0].replace("\\n", "\n")
if provider == GEMM_Provider.CK:
return bool(re.match(".*ck_gemm.*", summary))
return bool(re.search(".*ck_gemm.*", summary))
if provider == GEMM_Provider.ROCBLAS:
return bool(re.match(".*gpu::gemm.*", summary))
return bool(re.search(".*gpu::gemm.*", summary)) or bool(re.search(".*gpu::quant_gemm.*", summary))
if provider == GEMM_Provider.MLIR:
return bool(re.match(".*MLIR.*", summary)) #######
return bool(re.search(".*mlir_dot.*", summary)) or bool(re.search(".*mlir_quant_dot.*", summary))
def get_gemm_time(config, fp16, provider, timeout):
......@@ -116,15 +114,19 @@ def get_gemm_time(config, fp16, provider, timeout):
timeout=timeout,
env=dict(os.environ,
MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR))
MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e:
print(f"{provider.name} encountered and exception {e}")
return -100.0
verify_output(str(out.stdout), provider)
total_time = get_total_time(str(out.stdout))
print(f"{provider.name} finished in {total_time}")
return total_time
if verify_output(str(out.stdout), provider):
total_time = get_total_time(str(out.stdout))
print(f"{provider.name} total time: {total_time} ms")
return total_time
else:
print(f"{provider.name} was not found in performance summary")
return -100.0
def get_gemm_softmax_gemm_time(config, provider, timeout):
......@@ -142,17 +144,19 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
timeout=timeout,
env=dict(os.environ,
MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR))
MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e:
print(f"{provider.name} encountered and exception {e}")
return -100.0
summary = re.findall("Summary.*", str(out.stdout))[0].replace("\\n", "\n")
total_time = re.findall("Total time: \d+\.\d*", summary)[0]
total_time = total_time.replace("Total time: ", "")
total_time = float(total_time)
print(f"{provider.name} finished in {total_time}")
return total_time
if verify_output(str(out.stdout), provider):
total_time = get_total_time(str(out.stdout))
print(f"{provider.name} total time: {total_time} ms")
return total_time
else:
print(f"{provider.name} was not found in performance summary")
return -100.0
def run_gemm_perf(batches, sizes, fp16, timeout):
......@@ -163,7 +167,7 @@ def run_gemm_perf(batches, sizes, fp16, timeout):
out.write_row(["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
for shape in sizes:
config = (b,) + shape
print("Running gemm with config: {0}, {1}, {2}".format(*config))
print("Running {prec} gemm with config: {0}, {1}, {2}, {3}".format(prec=prec_str, *config))
ck_time = get_gemm_time(config, fp16, GEMM_Provider.CK, timeout)
rb_time = get_gemm_time(config, fp16, GEMM_Provider.ROCBLAS, timeout)
mlir_time = get_gemm_time(config, fp16, GEMM_Provider.MLIR, timeout)
......@@ -177,7 +181,7 @@ def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
out.write_row(["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
for shape in sizes:
config = (b,) + shape
print("Running gemm-softmax-gemm with config: {0}, {1}, {2}, {3}".format(*config))
print("Running fp16 gemm-softmax-gemm with config: {0}, {1}, {2}, {3}, {4}".format(*config))
ck_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.CK, timeout)
rb_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.ROCBLAS, timeout)
mlir_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.MLIR, timeout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment