Commit 280e76d0 authored by turneram
Browse files

Add onnx models and refactor gemm_perf

parent 092ec713
...@@ -2431,6 +2431,44 @@ def gathernd_batch_dims_test(): ...@@ -2431,6 +2431,44 @@ def gathernd_batch_dims_test():
return ([node], [x, i], [y]) return ([node], [x, i], [y])
@onnx_test()
def gemm_softmax_gemm_half():
    """Describe a half-precision MatMul -> Mul -> Add -> Softmax -> MatMul graph.

    The graph computes ``softmax(a @ b * scale + zeros) @ b1`` where ``scale``
    (1/8) and ``zeros`` (0) are one-element FLOAT16 initializers.  All graph
    inputs and the output are declared as FLOAT16 with shape [1, 1].

    Returns a tuple ``(nodes, inputs, outputs, initializers)``.
    NOTE(review): ``onnx_test`` is defined outside this view; presumably it
    turns this tuple into a serialized model -- confirm against its definition.
    """
    half = TensorProto.FLOAT16

    # Graph inputs/outputs all share the same element type and shape.
    a, b, b1, out = (helper.make_tensor_value_info(name, half, [1, 1])
                     for name in ('a', 'b', 'b1', 'out'))

    def half_initializer(name, value):
        # One-element FLOAT16 constant tensor holding ``value``.
        data = np.array([value])
        return helper.make_tensor(name=name,
                                  data_type=half,
                                  dims=data.shape,
                                  vals=data.flatten().astype(np.float16))

    scale_tensor = half_initializer('scale', 1 / 8)
    zero_tensor = half_initializer('zeros', 0)

    # Nodes are listed in execution order; each consumes the previous output.
    nodes = [
        onnx.helper.make_node('MatMul',
                              inputs=['a', 'b'],
                              outputs=['gemm1_out']),
        onnx.helper.make_node('Mul',
                              inputs=['gemm1_out', 'scale'],
                              outputs=['mul1_out']),
        onnx.helper.make_node('Add',
                              inputs=['mul1_out', 'zeros'],
                              outputs=['add1_out']),
        onnx.helper.make_node('Softmax',
                              inputs=['add1_out'],
                              outputs=['softmax_out']),
        onnx.helper.make_node('MatMul',
                              inputs=['softmax_out', 'b1'],
                              outputs=['out']),
    ]

    return (nodes, [a, b, b1], [out], [scale_tensor, zero_tensor])
@onnx_test() @onnx_test()
def gemm_test(): def gemm_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6])
...@@ -4135,6 +4173,21 @@ def lrn_test(): ...@@ -4135,6 +4173,21 @@ def lrn_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def matmul_half():
    """Describe a single FLOAT16 MatMul of inputs '1' and '2' producing 'y'.

    Both inputs and the output are declared with shape [1, 1, 1].

    Returns a tuple ``(nodes, inputs, outputs)``.
    NOTE(review): ``onnx_test`` is defined outside this view; presumably it
    turns this tuple into a serialized model -- confirm against its definition.
    """
    # A fresh dims list is built per tensor so no list object is shared.
    m1, m2, y = (helper.make_tensor_value_info(name, TensorProto.FLOAT16,
                                               [1, 1, 1])
                 for name in ('1', '2', 'y'))

    matmul = onnx.helper.make_node('MatMul', inputs=['1', '2'], outputs=['y'])

    return ([matmul], [m1, m2], [y])
@onnx_test() @onnx_test()
def matmul_bmbm_test(): def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
......
  matmul_half:k

1
2y"MatMul matmul_halfZ
1




Z
2




b
y




B
\ No newline at end of file
#%matplotlib #%matplotlib
import subprocess, csv, re, datetime, argparse, os import subprocess, csv, re, datetime, argparse, os, enum
from subprocess import STDOUT from subprocess import STDOUT
class GEMM_Provider(enum.Enum):
    """Identifies which GEMM implementation a benchmark run should use."""

    CK = 1
    ROCBLAS = 2
    MLIR = 3
def parse_args(argv=None):
    """Parse the command-line options of the GEMM performance tool.

    Args:
        argv: Optional list of argument strings.  When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``gemm``, ``int8``,
        ``gemm_softmax_gemm`` (bools), ``batch_sizes`` (list of int),
        ``exclude_lens`` (list of int, empty when the option is omitted)
        and ``timeout`` (int, seconds).
    """
    parser = argparse.ArgumentParser(description="GEMM performance tools")
    parser.add_argument(
        '--gemm',
        action='store_true',
        help='Run performance comparison on a range of GEMM problem sizes (fp16/int8)')
    parser.add_argument(
        '--int8',
        action='store_true',
        help='Quantize GEMMs to int8 precision (not available for GEMM-Softmax-GEMM)')
    parser.add_argument(
        '--gemm-softmax-gemm',
        action='store_true',
        help='Run performance comparison on a range of GEMM-Softmax-GEMM problem sizes (fp16)')
    # type=int rejects non-numeric values at parse time instead of letting
    # them reach the driver command line.
    parser.add_argument('--batch_sizes',
                        '-b',
                        nargs='+',
                        type=int,
                        help='Batch sizes to run',
                        required=True)
    # default=[] keeps the attribute iterable when the option is omitted;
    # argparse would otherwise store None.
    parser.add_argument(
        '--exclude-lens',
        '-e',
        nargs='+',
        type=int,
        default=[],
        help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]')
    parser.add_argument('--timeout',
                        '-t',
                        type=int,
                        default=600,
                        help='Time in seconds before compilation timeout')
    return parser.parse_args(argv)
class CSVFile:
    """Append-only CSV writer bound to a file path.

    The file is opened, appended to and closed on every ``write_row`` call,
    so rows already written stay on disk if a later step fails.
    """

    def __init__(self, path="output.csv"):
        """Remember the path of the CSV file to append to."""
        self.path = path

    def write_row(self, row=()):
        """Append one row, converting every cell to ``str`` first.

        Args:
            row: Iterable of cell values.  Defaults to an empty row.
        """
        cells = [str(cell) for cell in row]
        # newline="" lets the csv module control line endings itself, as its
        # documentation requires; without it extra blank lines can appear on
        # platforms that translate "\n".
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow(cells)
...@@ -34,167 +59,98 @@ def get_device_name(): ...@@ -34,167 +59,98 @@ def get_device_name():
return matches[0] return matches[0]
def get_gemm_time(config, fp16, provider, timeout=600):
    """Run the MIGraphX driver on one GEMM problem and return its total time.

    Args:
        config: ``(b, m, n, k)``; the driver is given input '1' with dims
            ``b m k`` and input '2' with dims ``b k n``.
        fp16: When true pass ``--fp16`` to the driver, otherwise ``--int8``.
        provider: GEMM_Provider member.  MIGRAPHX_ENABLE_CK is set to "1"
            only for CK and MIGRAPHX_ENABLE_MLIR only for MLIR; both are "0"
            for ROCBLAS.
        timeout: Seconds to wait for the driver before giving up.

    Returns:
        The "Total time" value from the driver summary as a float, or
        -100.0 when the driver fails, times out, or prints no summary.
    """
    model = "../test/onnx/matmul_half.onnx"
    b, m, n, k = config
    prec_str = "--fp16" if fp16 else "--int8"
    cmd = f"../build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n} {prec_str}"
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0")
    label = f"{provider.name} GEMM {b}, {m}, {n}, {k}:"

    try:
        out = subprocess.run(cmd.split(),
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=env,
                             text=True,
                             errors="replace")
    except (subprocess.SubprocessError, OSError) as e:
        # Non-zero exit, timeout or missing binary: record a sentinel so a
        # long sweep keeps going instead of aborting.
        print(f"An exception occurred: {str(e)}")
        print(label, end=" ")
        print("-100.0")
        return -100.0

    # Only the first "Total time" after the "Summary" marker is the figure
    # of interest.
    report = out.stdout
    start = report.find("Summary")
    match = None
    if start != -1:
        match = re.search(r"Total time: (\d+\.\d*)", report[start:])
    if match is None:
        print(f"{label} no 'Total time' found in driver output")
        return -100.0
    return float(match.group(1))
def get_gemm_softmax_gemm_time(config, provider, timeout=600):
    """Run the MIGraphX driver on one GEMM-Softmax-GEMM problem (fp16).

    Args:
        config: ``(b, m, n, k, o)``; the driver is given input 'a' with dims
            ``b m k``, 'b' with dims ``b k n`` and 'b1' with dims ``b n o``.
        provider: GEMM_Provider member.  MIGRAPHX_ENABLE_CK is set to "1"
            only for CK and MIGRAPHX_ENABLE_MLIR only for MLIR; both are "0"
            for ROCBLAS.
        timeout: Seconds to wait for the driver before giving up.

    Returns:
        The "Total time" value from the driver summary as a float, or
        -100.0 when the driver fails, times out, or prints no summary.
    """
    model = "../test/onnx/gemm_softmax_gemm_half.onnx"
    b, m, n, k, o = config
    cmd = f"../build/bin/driver perf {model} --input-dim @a {b} {m} {k} @b {b} {k} {n} @b1 {b} {n} {o} --fp16"
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0")
    label = f"{provider.name} GEMM-Softmax-GEMM {b}, {m}, {n}, {k}, {o}:"

    try:
        out = subprocess.run(cmd.split(),
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=env,
                             text=True,
                             errors="replace")
    except (subprocess.SubprocessError, OSError) as e:
        # Report the failure rather than hiding it; the sentinel alone does
        # not say why a measurement is missing.
        print(f"An exception occurred: {str(e)}")
        print(label, end=" ")
        print("-100.0")
        return -100.0

    # Only the first "Total time" after the "Summary" marker is the figure
    # of interest.
    report = out.stdout
    start = report.find("Summary")
    match = None
    if start != -1:
        match = re.search(r"Total time: (\d+\.\d*)", report[start:])
    if match is None:
        print(f"{label} no 'Total time' found in driver output")
        return -100.0
    return float(match.group(1))
def run_gemm_perf(batches, sizes, fp16, timeout=600):
    """Benchmark every (m, n, k) combination of ``sizes`` for each batch size.

    For each batch size ``b`` a file ``gemm_perf_{half|int}_{b}.csv`` is
    appended to: a device/timestamp row, a header row, then one row per
    problem holding the CK, rocBLAS and MLIR total times.

    Args:
        batches: Iterable of batch sizes.
        sizes: Iterable of lengths; m, n and k each range over all of them.
        fp16: True for half precision, False for int8.
        timeout: Seconds allowed per driver invocation.
    """
    prec_str = "half" if fp16 else "int"
    # Order must match the column order of the header row below.
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    sizes = list(sizes)
    # The shape list does not depend on the batch size, so build it once.
    shapes = [(m, n, k) for m in sizes for n in sizes for k in sizes]

    for b in batches:
        out = CSVFile(f"gemm_perf_{prec_str}_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row([
            "batch_size", "m", "n", "k", "CK Total Time (ms)",
            "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"
        ])
        for shape in shapes:
            config = (b,) + shape
            times = [
                get_gemm_time(config, fp16, provider, timeout)
                for provider in providers
            ]
            out.write_row(list(config) + times)
def run_gemm_softmax_gemm_perf(batches, sizes, timeout=600):
    """Benchmark every (m, n, k, o) combination of ``sizes`` per batch size.

    For each batch size ``b`` a file ``gemm_softmax_gemm_perf_fp16_{b}.csv``
    is appended to: a device/timestamp row, a header row, then one row per
    problem holding the CK, rocBLAS and MLIR total times.

    Args:
        batches: Iterable of batch sizes.
        sizes: Iterable of lengths; m, n, k and o each range over all of them.
        timeout: Seconds allowed per driver invocation.
    """
    # Order must match the column order of the header row below.
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    sizes = list(sizes)
    # The shape list does not depend on the batch size, so build it once.
    shapes = [(m, n, k, o) for m in sizes for n in sizes for k in sizes
              for o in sizes]

    for b in batches:
        out = CSVFile(f"gemm_softmax_gemm_perf_fp16_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row([
            "batch_size", "m", "n", "k", "o", "CK Total Time (ms)",
            "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"
        ])
        for shape in shapes:
            config = (b,) + shape
            # The five-element config belongs to the GEMM-Softmax-GEMM timer;
            # the plain GEMM timer unpacks only (b, m, n, k).
            times = [
                get_gemm_softmax_gemm_time(config, provider, timeout)
                for provider in providers
            ]
            out.write_row(list(config) + times)
if __name__ == "__main__":
    # Script entry point: parse options, filter the problem lengths, then run
    # whichever benchmark sweeps were requested.
    args = parse_args()

    # --exclude-lens is optional, so tolerate it being unset.
    exclude = {int(x) for x in (args.exclude_lens or [])}
    sizes = [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]
    sizes = [x for x in sizes if x not in exclude]

    fp16 = not args.int8
    timeout = int(args.timeout)
    # The option is declared as --batch_sizes, so argparse stores the values
    # under ``batch_sizes``.
    batches = args.batch_sizes

    # Both runners expand ``sizes`` into the full cross product themselves,
    # so they receive the flat list of lengths rather than shape tuples.
    if args.gemm:
        run_gemm_perf(batches, sizes, fp16, timeout)
    if args.gemm_softmax_gemm:
        run_gemm_softmax_gemm_perf(batches, sizes, timeout)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment