Commit 70b7a68f authored by Alan Turner
Browse files

Further improvements

parent 6cb25a15
#%matplotlib
import subprocess, csv, re, datetime, argparse, os, enum import subprocess, csv, re, datetime, argparse, os, enum
from subprocess import STDOUT
def parse_args(): def parse_args():
...@@ -93,11 +91,11 @@ def get_total_time(output): ...@@ -93,11 +91,11 @@ def get_total_time(output):
def verify_output(output, provider): def verify_output(output, provider):
summary = re.findall("Summary.*", output)[0].replace("\\n", "\n") summary = re.findall("Summary.*", output)[0].replace("\\n", "\n")
if provider == GEMM_Provider.CK: if provider == GEMM_Provider.CK:
return bool(re.match(".*ck_gemm.*", summary)) return bool(re.search(".*ck_gemm.*", summary))
if provider == GEMM_Provider.ROCBLAS: if provider == GEMM_Provider.ROCBLAS:
return bool(re.match(".*gpu::gemm.*", summary)) return bool(re.search(".*gpu::gemm.*", summary)) or bool(re.search(".*gpu::quant_gemm.*", summary))
if provider == GEMM_Provider.MLIR: if provider == GEMM_Provider.MLIR:
return bool(re.match(".*MLIR.*", summary)) ####### return bool(re.search(".*mlir_dot.*", summary)) or bool(re.search(".*mlir_quant_dot.*", summary))
def get_gemm_time(config, fp16, provider, timeout): def get_gemm_time(config, fp16, provider, timeout):
...@@ -116,15 +114,19 @@ def get_gemm_time(config, fp16, provider, timeout): ...@@ -116,15 +114,19 @@ def get_gemm_time(config, fp16, provider, timeout):
timeout=timeout, timeout=timeout,
env=dict(os.environ, env=dict(os.environ,
MIGRAPHX_ENABLE_CK=use_CK, MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR)) MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e: except Exception as e:
print(f"{provider.name} encountered and exception {e}")
return -100.0 return -100.0
verify_output(str(out.stdout), provider) if verify_output(str(out.stdout), provider):
total_time = get_total_time(str(out.stdout)) total_time = get_total_time(str(out.stdout))
print(f"{provider.name} finished in {total_time}") print(f"{provider.name} total time: {total_time} ms")
return total_time
return total_time else:
print(f"{provider.name} was not found in performance summary")
return -100.0
def get_gemm_softmax_gemm_time(config, provider, timeout): def get_gemm_softmax_gemm_time(config, provider, timeout):
...@@ -142,17 +144,19 @@ def get_gemm_softmax_gemm_time(config, provider, timeout): ...@@ -142,17 +144,19 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
timeout=timeout, timeout=timeout,
env=dict(os.environ, env=dict(os.environ,
MIGRAPHX_ENABLE_CK=use_CK, MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR)) MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e: except Exception as e:
print(f"{provider.name} encountered and exception {e}")
return -100.0 return -100.0
summary = re.findall("Summary.*", str(out.stdout))[0].replace("\\n", "\n") if verify_output(str(out.stdout), provider):
total_time = re.findall("Total time: \d+\.\d*", summary)[0] total_time = get_total_time(str(out.stdout))
total_time = total_time.replace("Total time: ", "") print(f"{provider.name} total time: {total_time} ms")
total_time = float(total_time) return total_time
print(f"{provider.name} finished in {total_time}") else:
print(f"{provider.name} was not found in performance summary")
return total_time return -100.0
def run_gemm_perf(batches, sizes, fp16, timeout): def run_gemm_perf(batches, sizes, fp16, timeout):
...@@ -163,7 +167,7 @@ def run_gemm_perf(batches, sizes, fp16, timeout): ...@@ -163,7 +167,7 @@ def run_gemm_perf(batches, sizes, fp16, timeout):
out.write_row(["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"]) out.write_row(["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
for shape in sizes: for shape in sizes:
config = (b,) + shape config = (b,) + shape
print("Running gemm with config: {0}, {1}, {2}".format(*config)) print("Running {prec} gemm with config: {0}, {1}, {2}, {3}".format(prec=prec_str, *config))
ck_time = get_gemm_time(config, fp16, GEMM_Provider.CK, timeout) ck_time = get_gemm_time(config, fp16, GEMM_Provider.CK, timeout)
rb_time = get_gemm_time(config, fp16, GEMM_Provider.ROCBLAS, timeout) rb_time = get_gemm_time(config, fp16, GEMM_Provider.ROCBLAS, timeout)
mlir_time = get_gemm_time(config, fp16, GEMM_Provider.MLIR, timeout) mlir_time = get_gemm_time(config, fp16, GEMM_Provider.MLIR, timeout)
...@@ -177,7 +181,7 @@ def run_gemm_softmax_gemm_perf(batches, sizes, timeout): ...@@ -177,7 +181,7 @@ def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
out.write_row(["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"]) out.write_row(["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
for shape in sizes: for shape in sizes:
config = (b,) + shape config = (b,) + shape
print("Running gemm-softmax-gemm with config: {0}, {1}, {2}, {3}".format(*config)) print("Running fp16 gemm-softmax-gemm with config: {0}, {1}, {2}, {3}, {4}".format(*config))
ck_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.CK, timeout) ck_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.CK, timeout)
rb_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.ROCBLAS, timeout) rb_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.ROCBLAS, timeout)
mlir_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.MLIR, timeout) mlir_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.MLIR, timeout)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment