Commit 6cb25a15 authored by turneram
Browse files

Add improvements

parent 7a80bd37
......@@ -3,12 +3,6 @@ import subprocess, csv, re, datetime, argparse, os, enum
from subprocess import STDOUT
class GEMM_Provider(enum.Enum):
CK = 1
ROCBLAS = 2
MLIR = 3
def parse_args():
parser = argparse.ArgumentParser(description="GEMM performance tools")
parser.add_argument('--gemm',
......@@ -28,7 +22,15 @@ def parse_args():
parser.add_argument('--exclude-lens',
'-e',
nargs='+',
help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]')
help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096] \
Lengths not excluded will be permuted as m, n, k, (o) inputs')
parser.add_argument('--lens-from-file',
'-l',
help='Run a list of problem lens from file containing rows of \
m0 n0 k0 (o0) \
m1 n1 k1 (o1) \
... \
(oi) only used for gemm-softmax-gemm')
parser.add_argument('--timeout',
'-t',
type=int,
......@@ -50,6 +52,16 @@ class CSVFile:
cw.writerow(row)
@enum.unique
class GEMM_Provider(enum.Enum):
    """Identifies which GEMM implementation a benchmark run should use."""

    CK = 1
    ROCBLAS = 2
    MLIR = 3
def get_migraphx_root():
    """Return the absolute path of the directory two levels above this file."""
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    # The repository root is taken to be the parent of the script's directory.
    return os.path.dirname(script_dir)
def get_device_name():
out = subprocess.run("rocminfo",
capture_output=True,
......@@ -59,11 +71,41 @@ def get_device_name():
return matches[0]
def verify_format(file, single):
    """Check that every line of a problem-lens file has the expected shape.

    Args:
        file: Path of the lens file. A falsy value (no file supplied)
            yields False without touching the filesystem.
        single: True to require three integers per line (m n k); False to
            require four (m n k o).

    Returns:
        True when every line consists of exactly the expected number of
        whitespace-separated non-negative integers (an empty file passes
        vacuously), otherwise False.
    """
    if not file:
        return False
    extra_fields = 2 if single else 3
    # \d+ so multi-digit lengths such as 1024 are accepted; the previous
    # pattern used a bare \d and therefore only matched single-digit values.
    pattern = re.compile(r"^\d+(?:\s+\d+){%d}\s*$" % extra_fields)
    with open(file, 'r') as f:
        return all(pattern.match(line) is not None for line in f)
def parse_lens(file):
    """Read problem lengths from *file*.

    Each line is split on whitespace and converted to a tuple of ints; the
    tuples are returned in file order.
    """
    lens = []
    with open(file, 'r') as handle:
        for row in handle:
            lens.append(tuple(int(token) for token in row.split()))
    return lens
def get_total_time(output):
    """Extract the "Total time" figure from the summary section of *output*.

    Args:
        output: Text containing a "Summary" section. Literal backslash-n
            sequences (as produced by str() of a bytes object) are turned
            into real newlines before the time is searched for.

    Returns:
        The first "Total time: <number>" value found in the summary, as a
        float.

    Raises:
        IndexError: if *output* has no "Summary" section or the summary has
            no "Total time" entry.
    """
    summary = re.findall(r"Summary.*", output)[0].replace("\\n", "\n")
    # Raw string: \d and \. are regex escapes, not Python string escapes.
    # The capture group returns just the number, so no prefix stripping is
    # needed afterwards.
    total_time = re.findall(r"Total time: (\d+\.\d*)", summary)[0]
    return float(total_time)
def verify_output(output, provider):
    """Report whether the perf summary mentions the requested GEMM provider.

    Args:
        output: Text containing a "Summary" section; literal backslash-n
            sequences are converted to real newlines before searching.
        provider: The GEMM_Provider member whose marker string is looked for.

    Returns:
        True if the provider's marker ("ck_gemm", "gpu::gemm" or "MLIR")
        occurs anywhere in the summary, False otherwise (including for an
        unrecognised provider).

    Raises:
        IndexError: if *output* has no "Summary" section.
    """
    summary = re.findall(r"Summary.*", output)[0].replace("\\n", "\n")
    if provider == GEMM_Provider.CK:
        marker = "ck_gemm"
    elif provider == GEMM_Provider.ROCBLAS:
        marker = "gpu::gemm"
    elif provider == GEMM_Provider.MLIR:
        marker = "MLIR"
    else:
        return False
    # Plain substring search over the whole summary. The earlier
    # re.match(".*marker.*") was anchored at the start and "." does not cross
    # newlines, so only the first summary line was ever inspected.
    return marker in summary
def get_gemm_time(config, fp16, provider, timeout):
model = "../test/onnx/matmul_half.onnx"
root = get_migraphx_root()
model = f"{root}/test/onnx/matmul_half.onnx"
b, m, n, k = config
prec_str = "--fp16" if fp16 else "--int8"
cmd = f"../build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n} {prec_str}"
cmd = f"{root}/build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n} {prec_str} --exhaustive-tune"
use_CK = "1" if provider == GEMM_Provider.CK else "0"
use_MLIR = "1" if provider == GEMM_Provider.MLIR else "0"
......@@ -76,22 +118,20 @@ def get_gemm_time(config, fp16, provider, timeout):
MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR))
except Exception as e:
print(f"An exception occurred: {str(e)}")
print(f"{provider.name} GEMM {b}, {m}, {n}, {k}:", end=" ")
print("-100.0")
return -100.0
summary = re.findall("Summary.*", str(out.stdout))[0].replace("\\n", "\n")
total_time = re.findall("Total time: \d+\.\d*", summary)[0]
total_time = total_time.replace("Total time: ", "")
verify_output(str(out.stdout), provider)
total_time = get_total_time(str(out.stdout))
print(f"{provider.name} finished in {total_time}")
return float(total_time)
return total_time
def get_gemm_softmax_gemm_time(config, provider, timeout):
model = "../test/onnx/gemm_softmax_gemm_half.onnx"
root = get_migraphx_root()
model = f"{root}/test/onnx/gemm_softmax_gemm_half.onnx"
b, m, n, k, o = config
cmd = f"../build/bin/driver perf {model} --input-dim @a {b} {m} {k} @b {b} {k} {n} @b1 {b} {n} {o} --fp16"
cmd = f"{root}/build/bin/driver perf {model} --input-dim @a {b} {m} {k} @b {b} {k} {n} @b1 {b} {n} {o} --fp16 --exhaustive-tune"
use_CK = "1" if provider == GEMM_Provider.CK else "0"
use_MLIR = "1" if provider == GEMM_Provider.MLIR else "0"
......@@ -109,34 +149,38 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
summary = re.findall("Summary.*", str(out.stdout))[0].replace("\\n", "\n")
total_time = re.findall("Total time: \d+\.\d*", summary)[0]
total_time = total_time.replace("Total time: ", "")
total_time = float(total_time)
print(f"{provider.name} finished in {total_time}")
return float(total_time)
return total_time
def run_gemm_perf(batches, sizes, fp16, timeout):
    """Benchmark plain GEMM for every batch size and problem shape.

    For each batch size a CSV file named gemm_perf_<precision>_<batch>.csv is
    written: a row with the device name and current timestamp, a header row,
    then one row per shape holding the config and the time reported for the
    CK, rocBLAS and MLIR providers.

    Args:
        batches: Iterable of batch sizes.
        sizes: Iterable of (m, n, k) shapes.
        fp16: True selects the "fp16" precision label, False "int8"; the flag
            is also forwarded to get_gemm_time.
        timeout: Forwarded to get_gemm_time for each run.
    """
    prec_str = "fp16" if fp16 else "int8"
    for b in batches:
        out = CSVFile(f"gemm_perf_{prec_str}_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row([
            "batch_size", "m", "n", "k", "CK Total Time (ms)",
            "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"
        ])
        for shape in sizes:
            config = (b,) + tuple(shape)
            # Show every element of the config; the fixed three-placeholder
            # format used before silently dropped k.
            print("Running gemm with config: " +
                  ", ".join(str(dim) for dim in config))
            ck_time = get_gemm_time(config, fp16, GEMM_Provider.CK, timeout)
            rb_time = get_gemm_time(config, fp16, GEMM_Provider.ROCBLAS,
                                    timeout)
            mlir_time = get_gemm_time(config, fp16, GEMM_Provider.MLIR,
                                      timeout)
            out.write_row(list(config) + [ck_time, rb_time, mlir_time])
def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
    """Benchmark gemm-softmax-gemm for every batch size and problem shape.

    For each batch size a CSV file named gemm_softmax_gemm_perf_fp16_<batch>.csv
    is written: a row with the device name and current timestamp, a header
    row, then one row per shape holding the config and the time reported for
    the CK, rocBLAS and MLIR providers.

    Args:
        batches: Iterable of batch sizes.
        sizes: Iterable of (m, n, k, o) shapes.
        timeout: Forwarded to get_gemm_softmax_gemm_time for each run.
    """
    for b in batches:
        out = CSVFile(f"gemm_softmax_gemm_perf_fp16_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row([
            "batch_size", "m", "n", "k", "o", "CK Total Time (ms)",
            "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"
        ])
        for shape in sizes:
            config = (b,) + tuple(shape)
            # Show every element of the config; the fixed four-placeholder
            # format used before silently dropped o.
            print("Running gemm-softmax-gemm with config: " +
                  ", ".join(str(dim) for dim in config))
            ck_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.CK,
                                                 timeout)
            rb_time = get_gemm_softmax_gemm_time(config,
                                                 GEMM_Provider.ROCBLAS,
                                                 timeout)
            mlir_time = get_gemm_softmax_gemm_time(config, GEMM_Provider.MLIR,
                                                   timeout)
            out.write_row(list(config) + [ck_time, rb_time, mlir_time])
......@@ -150,7 +194,11 @@ if __name__ == "__main__":
timeout = int(args.timeout)
if args.gemm:
gemm_sizes = [(m, n, k) for m in sizes for n in sizes for k in sizes]
run_gemm_perf(args.batches, gemm_sizes, fp16, timeout)
if verify_format(args.lens_from_file, True):
gemm_sizes = parse_lens(args.lens_from_file)
run_gemm_perf(args.batch_sizes, gemm_sizes, fp16, timeout)
if args.gemm_softmax_gemm:
gemm_softmax_gemm_sizes = [(m, n, k, o) for m in sizes for n in sizes for k in sizes for o in sizes]
run_gemm_softmax_gemm_perf(args.batches, gemm_sizes, timeout)
if verify_format(args.lens_from_file, False):
gemm_softmax_gemm_sizes = parse_lens(args.lens_from_file)
run_gemm_softmax_gemm_perf(args.batch_sizes, gemm_softmax_gemm_sizes, timeout)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment