Commit 0953c1c6 authored by one's avatar one
Browse files

Support shapes in gemm-flops yaml

parent 114dbb4f
...@@ -29,24 +29,25 @@ host-side dispatch overhead, steady-state launch throughput, and device-side lau ...@@ -29,24 +29,25 @@ host-side dispatch overhead, steady-state launch throughput, and device-side lau
Measure the GPU GEMM FLOPS for different float and int data types, with or without Tensor Core (XDLOPS), Measure the GPU GEMM FLOPS for different float and int data types, with or without Tensor Core (XDLOPS),
performed by NVIDIA [cutlass](https://github.com/NVIDIA/cutlass/tree/ccb697bac77fcc898e9c897b2c90aa5b60ac72fb) performed by NVIDIA [cutlass](https://github.com/NVIDIA/cutlass/tree/ccb697bac77fcc898e9c897b2c90aa5b60ac72fb)
or AMD [rocblas-bench](https://github.com/ROCmSoftwarePlatform/rocBLAS/tree/develop/clients/benchmarks). or AMD [rocblas-bench](https://github.com/ROCmSoftwarePlatform/rocBLAS/tree/develop/clients/benchmarks).
The benchmark supports one or more GEMM shapes in `m,n,k` format; each dimension may also be given as a range in `start:stop:factor` format, e.g., `4096:32768:2`.
#### Metrics #### Metrics
| Name | Unit | Description | | Name | Unit | Description |
|------------------------------|----------------|---------------------------------------------------------| |------------------------------|----------------|---------------------------------------------------------|
| gemm-flops/fp64_flops | FLOPS (GFLOPS) | GEMM float64 peak FLOPS. | | gemm-flops/fp64_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float64 peak FLOPS. |
| gemm-flops/fp32_flops | FLOPS (GFLOPS) | GEMM float32 peak FLOPS. | | gemm-flops/fp32_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float32 peak FLOPS. |
| gemm-flops/fp16_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS. | | gemm-flops/fp16_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS. |
| gemm-flops/fp64_tc_flops | FLOPS (GFLOPS) | GEMM float64 peak FLOPS with NVIDIA Tensor Core. | | gemm-flops/fp64_tc_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float64 peak FLOPS with NVIDIA Tensor Core. |
| gemm-flops/tf32_tc_flops | FLOPS (GFLOPS) | GEMM tensor-float32 peak FLOPS with NVIDIA Tensor Core. | | gemm-flops/tf32_tc_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM tensor-float32 peak FLOPS with NVIDIA Tensor Core. |
| gemm-flops/fp16_tc_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS with NVIDIA Tensor Core. | | gemm-flops/fp16_tc_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS with NVIDIA Tensor Core. |
| gemm-flops/bf16_tc_flops | FLOPS (GFLOPS) | GEMM bfloat16 peak FLOPS with NVIDIA Tensor Core. | | gemm-flops/bf16_tc_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM bfloat16 peak FLOPS with NVIDIA Tensor Core. |
| gemm-flops/int8_tc_iops | IOPS (GIOPS) | GEMM int8 peak IOPS with NVIDIA Tensor Core. | | gemm-flops/int8_tc_m${m}_n${n}_k${k}_iops | IOPS (GIOPS) | GEMM int8 peak IOPS with NVIDIA Tensor Core. |
| gemm-flops/int4_tc_iops | IOPS (GIOPS) | GEMM int4 peak IOPS with NVIDIA Tensor Core. | | gemm-flops/int4_tc_m${m}_n${n}_k${k}_iops | IOPS (GIOPS) | GEMM int4 peak IOPS with NVIDIA Tensor Core. |
| gemm-flops/fp32_xdlops_flops | FLOPS (GFLOPS) | GEMM tensor-float32 peak FLOPS with AMD XDLOPS. | | gemm-flops/fp32_xdlops_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM tensor-float32 peak FLOPS with AMD XDLOPS. |
| gemm-flops/fp16_xdlops_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS with AMD XDLOPS. | | gemm-flops/fp16_xdlops_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM float16 peak FLOPS with AMD XDLOPS. |
| gemm-flops/bf16_xdlops_flops | FLOPS (GFLOPS) | GEMM bfloat16 peak FLOPS with AMD XDLOPS. | | gemm-flops/bf16_xdlops_m${m}_n${n}_k${k}_flops | FLOPS (GFLOPS) | GEMM bfloat16 peak FLOPS with AMD XDLOPS. |
| gemm-flops/int8_xdlops_iops | IOPS (GIOPS) | GEMM int8 peak IOPS with AMD XDLOPS. | | gemm-flops/int8_xdlops_m${m}_n${n}_k${k}_iops | IOPS (GIOPS) | GEMM int8 peak IOPS with AMD XDLOPS. |
### `matmul` ### `matmul`
......
...@@ -92,15 +92,18 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -92,15 +92,18 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark):
if not super()._preprocess(): if not super()._preprocess():
return False return False
self._precision_shape_in_commands = []
for p in self._precision_need_to_run: for p in self._precision_need_to_run:
for m, n, k in self._shapes_to_run:
command = os.path.join(self._args.bin_dir, self._bin_name) command = os.path.join(self._args.bin_dir, self._bin_name)
command += (' --warmup-iterations=' + str(self._args.num_warmup)) command += (' --warmup-iterations=' + str(self._args.num_warmup))
command += (' --operation=gemm') command += (' --operation=gemm')
command += (' --n=' + str(self._args.n)) command += (' --n=' + str(n))
command += (' --k=' + str(self._args.k)) command += (' --k=' + str(k))
command += (' --m=' + str(self._args.m)) command += (' --m=' + str(m))
command += (' --kernels=' + self.__kernel_map[capability][p]) command += (' --kernels=' + self.__kernel_map[capability][p])
self._commands.append(command) self._commands.append(command)
self._precision_shape_in_commands.append((p, m, n, k))
return True return True
...@@ -116,7 +119,7 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -116,7 +119,7 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark):
Return: Return:
True if the raw output string is valid and result can be extracted. True if the raw output string is valid and result can be extracted.
""" """
precision = self._precision_need_to_run[cmd_idx] precision, m, n, k = self._precision_shape_in_commands[cmd_idx]
self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data) self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data)
valid = True valid = True
...@@ -138,7 +141,7 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -138,7 +141,7 @@ class CudaGemmFlopsBenchmark(GemmFlopsBenchmark):
) )
return False return False
self._result.add_result(self._metric_map[precision], max(flops)) self._result.add_result(self._get_metric_name(precision, m, n, k), max(flops))
return True return True
......
...@@ -262,14 +262,16 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -262,14 +262,16 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark):
if not super()._preprocess(): if not super()._preprocess():
return False return False
self._precision_shape_in_commands = []
for p in self._precision_need_to_run: for p in self._precision_need_to_run:
for m, n, k in self._shapes_to_run:
command = os.path.join(self._args.bin_dir, self._bin_name) command = os.path.join(self._args.bin_dir, self._bin_name)
command += ' ' + self.__precision_and_kernel_map[p] command += ' ' + self.__precision_and_kernel_map[p]
command += ' --iters {}'.format(self._args.iterations) command += ' --iters {}'.format(self._args.iterations)
command += ' --cold_iters {}'.format(self._args.num_warmup) command += ' --cold_iters {}'.format(self._args.num_warmup)
command += ' --transposeA {} --transposeB {}'.format(self._args.transposeA, self._args.transposeB) command += ' --transposeA {} --transposeB {}'.format(self._args.transposeA, self._args.transposeB)
command += ' --side {} --uplo {} --diag {}'.format(self._args.side, self._args.uplo, self._args.diag) command += ' --side {} --uplo {} --diag {}'.format(self._args.side, self._args.uplo, self._args.diag)
command += ' -m {} -n {} -k {}'.format(self._args.m, self._args.n, self._args.k) command += ' -m {} -n {} -k {}'.format(m, n, k)
command += ' --alpha {} --beta {}'.format(self._args.alpha, self._args.beta) command += ' --alpha {} --beta {}'.format(self._args.alpha, self._args.beta)
command += ' --kl {} --ku {}'.format(self._args.kl, self._args.ku) command += ' --kl {} --ku {}'.format(self._args.kl, self._args.ku)
command += ' --lda {} --ldb {} --ldc {} --ldd {}'.format( command += ' --lda {} --ldb {} --ldc {} --ldd {}'.format(
...@@ -302,6 +304,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -302,6 +304,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark):
command += ' --device {}'.format(self._args.device) command += ' --device {}'.format(self._args.device)
command += ' --initialization {}'.format(self._args.initialization) command += ' --initialization {}'.format(self._args.initialization)
self._commands.append(command) self._commands.append(command)
self._precision_shape_in_commands.append((p, m, n, k))
return True return True
...@@ -317,7 +320,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -317,7 +320,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark):
Return: Return:
True if the raw output string is valid and result can be extracted. True if the raw output string is valid and result can be extracted.
""" """
precision = self._precision_need_to_run[cmd_idx] precision, m, n, k = self._precision_shape_in_commands[cmd_idx]
self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data) self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data)
content = raw_output.splitlines() content = raw_output.splitlines()
...@@ -345,7 +348,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -345,7 +348,7 @@ class DtkGemmFlopsBenchmark(GemmFlopsBenchmark):
) )
return False return False
self._result.add_result(self._metric_map[precision], gflops) self._result.add_result(self._get_metric_name(precision, m, n, k), gflops)
return True return True
......
...@@ -3,11 +3,46 @@ ...@@ -3,11 +3,46 @@
"""Module of the FLOPs performance benchmark base class.""" """Module of the FLOPs performance benchmark base class."""
import itertools
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import ReturnCode from superbench.benchmarks import ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
def mrange(start, stop=-1, factor=2, symbol='x'):
    """Generate a sequence that grows from start by multiplication or addition.

    The first value is always yielded before any bound is checked, so at least
    one value (start itself) is produced for every valid call.

    Args:
        start (int): First value of the sequence.
        stop (int): Inclusive upper bound. The default -1 means "no range",
            in which case only start is yielded.
        factor (int): Multiplier when symbol is 'x', increment when symbol is '+'.
        symbol (str): 'x' for a multiplicative step, '+' for an additive step.

    Yields:
        int: Successive values until the next one would exceed stop, would be 0,
            or the factor cannot make progress (below 2 for 'x', below 1 for '+').

    Raises:
        ValueError: If stop is given and symbol is neither 'x' nor '+'. Because
            this is a generator, the error surfaces when iteration begins.
    """
    # No stop given: the "range" is the single starting value.
    if stop == -1:
        yield start
        return

    if symbol not in ('x', '+'):
        raise ValueError(f'Invalid symbol {symbol}.')

    multiplicative = symbol == 'x'
    # Smallest factor that still advances the sequence for the chosen operator.
    min_factor = 2 if multiplicative else 1

    current = start
    while True:
        yield current
        current = current * factor if multiplicative else current + factor
        if current > stop or current == 0 or factor < min_factor:
            return
def validate_mrange(string):
    """Validate mrange string in format start[[:stop]:factor].

    Accepted forms are 'start', 'start:stop' and 'start:stop:factor', where
    start and stop are unsigned decimal integers and factor is an unsigned
    decimal integer optionally preceded by exactly one operator, '+' or 'x'.

    Args:
        string (str): Text describing a single dimension.

    Return:
        bool: True if the text matches one of the accepted forms.
    """
    nums = string.split(':')
    if len(nums) > 3:
        return False
    # str.isdecimal() is used instead of str.isdigit(): isdigit() also accepts
    # characters such as superscript digits that int() cannot parse, which
    # would let a "validated" string crash the later conversion.
    if len(nums) < 3:
        return all(x.isdecimal() for x in nums)

    factor = nums[2]
    # Drop at most one leading operator; repeated prefixes such as '++2' or
    # 'xx2' are malformed and must not be accepted.
    if factor[:1] in ('+', 'x'):
        factor = factor[1:]
    return nums[0].isdecimal() and nums[1].isdecimal() and factor.isdecimal()
class GemmFlopsBenchmark(MicroBenchmarkWithInvoke): class GemmFlopsBenchmark(MicroBenchmarkWithInvoke):
"""The GEMM FLOPs performance benchmark base class.""" """The GEMM FLOPs performance benchmark base class."""
def __init__(self, name, parameters=''): def __init__(self, name, parameters=''):
...@@ -23,6 +58,7 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke): ...@@ -23,6 +58,7 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke):
'fp64', 'fp32', 'fp16', 'fp64_tc', 'tf32_tc', 'bf16_tc', 'fp16_tc', 'int8_tc', 'int4_tc' 'fp64', 'fp32', 'fp16', 'fp64_tc', 'tf32_tc', 'bf16_tc', 'fp16_tc', 'int8_tc', 'int4_tc'
] ]
self._precision_need_to_run = list() self._precision_need_to_run = list()
self._shapes_to_run = list()
self._metric_map = { self._metric_map = {
'fp64': 'fp64_flops', 'fp64': 'fp64_flops',
'fp32': 'fp32_flops', 'fp32': 'fp32_flops',
...@@ -71,6 +107,13 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke): ...@@ -71,6 +107,13 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke):
required=False, required=False,
help='The M dim of matmul (N, K) * (K, M).', help='The M dim of matmul (N, K) * (K, M).',
) )
self._parser.add_argument(
'--shapes',
type=str,
nargs='+',
default=list(),
help='Shapes in m,n,k format. Support format start:stop:factor, e.g., 4096:32768:2.',
)
self._parser.add_argument( self._parser.add_argument(
'--precision', '--precision',
type=str, type=str,
...@@ -106,4 +149,31 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke): ...@@ -106,4 +149,31 @@ class GemmFlopsBenchmark(MicroBenchmarkWithInvoke):
self._result.set_return_code(ReturnCode.NO_SUPPORTED_PRECISION) self._result.set_return_code(ReturnCode.NO_SUPPORTED_PRECISION)
return False return False
shapes = self._args.shapes or [f'{self._args.m},{self._args.n},{self._args.k}']
for shape in shapes:
shape_list = shape.replace(',', ' ').split()
if len(shape_list) != 3 or not all(validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape - benchmark: {self._name}, shape: {shape}.')
return False
for m, n, k in itertools.product(
*map(
lambda dim: mrange(
*map(lambda value: int(value.lstrip('+').lstrip('x')), dim.split(':')),
symbol=dim.split(':')[2][0]
if len(dim.split(':')) == 3 and any([operator in dim for operator in ['+', 'x']]) else 'x'
), shape_list
)
):
self._shapes_to_run.append((m, n, k))
return True return True
def _get_metric_name(self, precision, m, n, k):
    """Build metric name with precision and GEMM shape.

    The shape tag '_m{m}_n{n}_k{k}' is inserted before a trailing '_flops' or
    '_iops' unit suffix of the base metric; when the base metric has neither
    suffix, the tag is appended at the end.

    Args:
        precision (str): Key into self._metric_map.
        m (int): M dimension of the GEMM shape.
        n (int): N dimension of the GEMM shape.
        k (int): K dimension of the GEMM shape.

    Return:
        str: Metric name that embeds the GEMM shape.
    """
    base = self._metric_map[precision]
    shape_tag = '_m{}_n{}_k{}'.format(m, n, k)
    for unit_suffix in ('_flops', '_iops'):
        if base.endswith(unit_suffix):
            # Keep the unit suffix last so the metric still ends in its unit.
            return base[:len(base) - len(unit_suffix)] + shape_tag + unit_suffix
    return base + shape_tag
...@@ -109,17 +109,20 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -109,17 +109,20 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark):
if not super()._preprocess(): if not super()._preprocess():
return False return False
self._precision_shape_in_commands = []
for p in self._precision_need_to_run: for p in self._precision_need_to_run:
for m, n, k in self._shapes_to_run:
command = os.path.join(self._args.bin_dir, self._bin_name) command = os.path.join(self._args.bin_dir, self._bin_name)
command += ' ' + self.__precision_and_kernel_map[p] command += ' ' + self.__precision_and_kernel_map[p]
command += ' --transposeA {} --transposeB {}'.format(self._args.transposeA, self._args.transposeB) command += ' --transposeA {} --transposeB {}'.format(self._args.transposeA, self._args.transposeB)
command += ' -m {} -n {} -k {}'.format(self._args.m, self._args.n, self._args.k) command += ' -m {} -n {} -k {}'.format(m, n, k)
command += ' --alpha {} --beta {}'.format(self._args.alpha, self._args.beta) command += ' --alpha {} --beta {}'.format(self._args.alpha, self._args.beta)
command += ' --lda {} --ldb {} --ldc {} --ldd {}'.format( command += ' --lda {} --ldb {} --ldc {} --ldd {}'.format(
self._args.lda, self._args.ldb, self._args.ldc, self._args.ldd self._args.lda, self._args.ldb, self._args.ldc, self._args.ldd
) )
command += ' --initialization {}'.format(self._args.initialization) command += ' --initialization {}'.format(self._args.initialization)
self._commands.append(command) self._commands.append(command)
self._precision_shape_in_commands.append((p, m, n, k))
return True return True
...@@ -135,7 +138,7 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -135,7 +138,7 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark):
Return: Return:
True if the raw output string is valid and result can be extracted. True if the raw output string is valid and result can be extracted.
""" """
precision = self._precision_need_to_run[cmd_idx] precision, m, n, k = self._precision_shape_in_commands[cmd_idx]
self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data) self._result.add_raw_data('raw_output_' + precision, raw_output, self._args.log_raw_data)
content = raw_output.splitlines() content = raw_output.splitlines()
...@@ -163,7 +166,7 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -163,7 +166,7 @@ class RocmGemmFlopsBenchmark(GemmFlopsBenchmark):
) )
return False return False
self._result.add_result(self._metric_map[precision], gflops) self._result.add_result(self._get_metric_name(precision, m, n, k), gflops)
return True return True
......
...@@ -40,9 +40,8 @@ superbench: ...@@ -40,9 +40,8 @@ superbench:
gemm-flops: gemm-flops:
<<: *default_local_mode <<: *default_local_mode
parameters: parameters:
m: 7680 shapes:
n: 8192 - 7680,8192,8192
k: 8192
hipblaslt-gemm: hipblaslt-gemm:
enable: true enable: true
modes: modes:
......
...@@ -37,9 +37,8 @@ superbench: ...@@ -37,9 +37,8 @@ superbench:
gemm-flops: gemm-flops:
<<: *default_local_mode <<: *default_local_mode
parameters: parameters:
m: 7680 shapes:
n: 8192 - 7680,8192,8192
k: 8192
hipblaslt-gemm: hipblaslt-gemm:
enable: true enable: true
modes: modes:
......
...@@ -87,9 +87,22 @@ Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B, ...@@ -87,9 +87,22 @@ Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B,
assert (benchmark._process_raw_result(1, raw_output_tf32_tc)) assert (benchmark._process_raw_result(1, raw_output_tf32_tc))
assert (benchmark._process_raw_result(2, raw_output_fp16_tc)) assert (benchmark._process_raw_result(2, raw_output_fp16_tc))
assert (benchmark.result['fp32_flops'][0] == 18369.7) assert (benchmark.result['fp32_m2048_n1024_k512_flops'][0] == 18369.7)
assert (benchmark.result['tf32_tc_flops'][0] == 128677) assert (benchmark.result['tf32_tc_m2048_n1024_k512_flops'][0] == 128677)
assert (benchmark.result['fp16_tc_flops'][0] == 281048) assert (benchmark.result['fp16_tc_m2048_n1024_k512_flops'][0] == 281048)
# Negative case - Add invalid raw output. # Negative case - Add invalid raw output.
assert (benchmark._process_raw_result(3, 'Invalid raw output') is False) assert (benchmark._process_raw_result(3, 'Invalid raw output') is False)
benchmark = benchmark_class(
benchmark_name,
parameters='--num_warmup 200 --precision fp32 --shapes 4096,4096,4096 8192:16384:2,4096,8192'
)
ret = benchmark._preprocess()
if dm.device_manager.get_device_compute_capability() in benchmark._CudaGemmFlopsBenchmark__kernel_map:
assert (ret is True)
assert (len(benchmark._commands) == 3)
expected_shapes = [(4096, 4096, 4096), (8192, 4096, 8192), (16384, 4096, 8192)]
assert (
[shape for _, *shape in benchmark._precision_shape_in_commands] == [list(x) for x in expected_shapes]
)
...@@ -31,14 +31,17 @@ class FakeGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -31,14 +31,17 @@ class FakeGemmFlopsBenchmark(GemmFlopsBenchmark):
return False return False
# Check the arguments and generate the commands # Check the arguments and generate the commands
self._precision_shape_in_commands = []
for precision in self._precision_need_to_run: for precision in self._precision_need_to_run:
for m, n, k in self._shapes_to_run:
command = os.path.join(self._args.bin_dir, self._bin_name) command = os.path.join(self._args.bin_dir, self._bin_name)
command += ' "--precision ' + precision command += ' "--precision ' + precision
command += ' --m ' + str(self._args.m) command += ' --m ' + str(m)
command += ' --n ' + str(self._args.n) command += ' --n ' + str(n)
command += ' --k ' + str(self._args.k) command += ' --k ' + str(k)
command += ' --num_warmup ' + str(self._args.num_warmup) + '"' command += ' --num_warmup ' + str(self._args.num_warmup) + '"'
self._commands.append(command) self._commands.append(command)
self._precision_shape_in_commands.append((precision, m, n, k))
return True return True
...@@ -61,9 +64,10 @@ class FakeGemmFlopsBenchmark(GemmFlopsBenchmark): ...@@ -61,9 +64,10 @@ class FakeGemmFlopsBenchmark(GemmFlopsBenchmark):
for param in params[1:]: for param in params[1:]:
key_value = param.split() key_value = param.split()
if key_value[0] == 'precision': if key_value[0] == 'precision':
if key_value[1] != self._precision_need_to_run[cmd_idx]: if key_value[1] != self._precision_shape_in_commands[cmd_idx][0]:
return False return False
metric = self._precision_need_to_run[cmd_idx] precision, m, n, k = self._precision_shape_in_commands[cmd_idx]
metric = self._get_metric_name(precision, m, n, k)
except BaseException: except BaseException:
return False return False
...@@ -95,7 +99,17 @@ def test_gemm_flops_performance_base(): ...@@ -95,7 +99,17 @@ def test_gemm_flops_performance_base():
command = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1] command = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
assert (command == expected_command[i]) assert (command == expected_command[i])
for i, metric in enumerate( for i, metric in enumerate(
['fp64', 'fp32', 'fp16', 'fp64_tc', 'tf32_tc', 'bf16_tc', 'fp16_tc', 'int8_tc', 'int4_tc'] [
'fp64_m16384_n16384_k16384_flops',
'fp32_m16384_n16384_k16384_flops',
'fp16_m16384_n16384_k16384_flops',
'fp64_tc_m16384_n16384_k16384_flops',
'tf32_tc_m16384_n16384_k16384_flops',
'bf16_tc_m16384_n16384_k16384_flops',
'fp16_tc_m16384_n16384_k16384_flops',
'int8_tc_m16384_n16384_k16384_iops',
'int4_tc_m16384_n16384_k16384_iops'
]
): ):
assert (metric in benchmark.result) assert (metric in benchmark.result)
assert (len(benchmark.result[metric]) == 1) assert (len(benchmark.result[metric]) == 1)
...@@ -114,7 +128,13 @@ def test_gemm_flops_performance_base(): ...@@ -114,7 +128,13 @@ def test_gemm_flops_performance_base():
for i in range(len(expected_command)): for i in range(len(expected_command)):
command = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1] command = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
assert (command == expected_command[i]) assert (command == expected_command[i])
for i, metric in enumerate(['fp64', 'fp32', 'fp16']): for i, metric in enumerate(
[
'fp64_m16384_n16384_k16384_flops',
'fp32_m16384_n16384_k16384_flops',
'fp16_m16384_n16384_k16384_flops'
]
):
assert (metric in benchmark.result) assert (metric in benchmark.result)
assert (len(benchmark.result[metric]) == 1) assert (len(benchmark.result[metric]) == 1)
...@@ -122,8 +142,28 @@ def test_gemm_flops_performance_base(): ...@@ -122,8 +142,28 @@ def test_gemm_flops_performance_base():
assert (benchmark._benchmark_type == BenchmarkType.MICRO) assert (benchmark._benchmark_type == BenchmarkType.MICRO)
assert (benchmark.run() is True) assert (benchmark.run() is True)
benchmark = FakeGemmFlopsBenchmark(
'fake',
parameters='--precision fp32 --shapes 4096,4096,4096 8192:16384:2,4096,8192'
)
assert (benchmark._benchmark_type == BenchmarkType.MICRO)
assert (benchmark.run() is True)
expected_command = [
'echo "--precision fp32 --m 4096 --n 4096 --k 4096 --num_warmup 5"',
'echo "--precision fp32 --m 8192 --n 4096 --k 8192 --num_warmup 5"',
'echo "--precision fp32 --m 16384 --n 4096 --k 8192 --num_warmup 5"',
]
assert (len(benchmark._commands) == len(expected_command))
for i in range(len(expected_command)):
command = benchmark._bin_name + benchmark._commands[i].split(benchmark._bin_name)[1]
assert (command == expected_command[i])
# Negative case - INVALID_ARGUMENT. # Negative case - INVALID_ARGUMENT.
benchmark = FakeGemmFlopsBenchmark('fake', parameters='--precision bf64') benchmark = FakeGemmFlopsBenchmark('fake', parameters='--precision bf64')
assert (benchmark._benchmark_type == BenchmarkType.MICRO) assert (benchmark._benchmark_type == BenchmarkType.MICRO)
assert (benchmark.run() is False) assert (benchmark.run() is False)
assert (benchmark.return_code == ReturnCode.NO_SUPPORTED_PRECISION) assert (benchmark.return_code == ReturnCode.NO_SUPPORTED_PRECISION)
benchmark = FakeGemmFlopsBenchmark('fake', parameters='--shapes 4096,4096')
assert (benchmark._benchmark_type == BenchmarkType.MICRO)
assert (benchmark.run() is False)
...@@ -85,11 +85,21 @@ T,N,7680,8192,8192,1,8416,0,8416,8416,8416,1, 162675, 6336.5 ...@@ -85,11 +85,21 @@ T,N,7680,8192,8192,1,8416,0,8416,8416,8416,1, 162675, 6336.5
assert (benchmark._process_raw_result(3, raw_output_BF16_X)) assert (benchmark._process_raw_result(3, raw_output_BF16_X))
assert (benchmark._process_raw_result(4, raw_output_INT8_X)) assert (benchmark._process_raw_result(4, raw_output_INT8_X))
assert (benchmark.result['fp64_flops'][0] == 10037.5) assert (benchmark.result['fp64_m7680_n8192_k8192_flops'][0] == 10037.5)
assert (benchmark.result['fp32_xdlops_flops'][0] == 39441.6) assert (benchmark.result['fp32_xdlops_m7680_n8192_k8192_flops'][0] == 39441.6)
assert (benchmark.result['fp16_xdlops_flops'][0] == 153728) assert (benchmark.result['fp16_xdlops_m7680_n8192_k8192_flops'][0] == 153728)
assert (benchmark.result['bf16_xdlops_flops'][0] == 81374.3) assert (benchmark.result['bf16_xdlops_m7680_n8192_k8192_flops'][0] == 81374.3)
assert (benchmark.result['int8_xdlops_iops'][0] == 162675) assert (benchmark.result['int8_xdlops_m7680_n8192_k8192_iops'][0] == 162675)
# Negative case - Add invalid raw output. # Negative case - Add invalid raw output.
assert (benchmark._process_raw_result(4, 'Invalid raw output') is False) assert (benchmark._process_raw_result(4, 'Invalid raw output') is False)
benchmark = benchmark_class(
benchmark_name,
parameters='--precision fp32_xdlops --shapes 4096,4096,4096 8192:16384:2,4096,8192'
)
assert (benchmark._preprocess() is True)
assert (len(benchmark._commands) == 3)
expected_shapes = [(4096, 4096, 4096), (8192, 4096, 8192), (16384, 4096, 8192)]
assert ([shape for _, *shape in benchmark._precision_shape_in_commands] == [list(x) for x in expected_shapes])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment