Unverified Commit e8777e24 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Benchmarks: Micro benchmarks - add fp8 and initialization for hipblaslt benchmark (#605)

**Description**
Benchmarks: Micro benchmarks - add fp8 and initialization for hipblaslt
benchmark.
parent c635f755
......@@ -4,7 +4,6 @@
"""Module of the hipBlasLt GEMM benchmark."""
import os
import re
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
......@@ -23,11 +22,12 @@ def __init__(self, name, parameters=''):
super().__init__(name, parameters)
self._bin_name = 'hipblaslt-bench'
self._in_types = ['fp32', 'fp16', 'bf16']
self._in_types = ['fp32', 'fp16', 'bf16', 'fp8']
self._in_type_map = {
'fp16': '--a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --compute_type f32_r',
'fp32': '--a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --compute_type f32_r',
'bf16': '--a_type bf16_r --b_type bf16_r --c_type bf16_r --d_type bf16_r --compute_type f32_r',
'fp8': '--a_type f8_r --b_type f8_r --c_type f8_r --d_type f8_r --compute_type f32_r',
}
def add_parser_arguments(self):
......@@ -42,6 +42,30 @@ def add_parser_arguments(self):
required=False,
help='List of input data types, support {}.'.format(' '.join(self._in_types)),
)
self._parser.add_argument(
'--initialization',
type=str,
default='rand_int',
choices=['trig_float', 'rand_int', 'hpl'],
required=False,
help='Initialize matrix data.',
)
self._parser.add_argument(
'--transA',
type=str,
default='N',
choices=['N', 'T', 'C'],
required=False,
help='Transpose matrix A.',
)
self._parser.add_argument(
'--transB',
type=str,
default='N',
choices=['N', 'T', 'C'],
required=False,
help='Transpose matrix B.',
)
def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
......@@ -58,7 +82,9 @@ def _preprocess(self):
self._precision_in_commands = []
for (_m, _n, _k, _b, _in_type) in self._shapes_to_run:
command = f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -j {self._args.num_warmup}' + \
f' -i {self._args.num_steps} {self._in_type_map[_in_type]}'
f' -i {self._args.num_steps} {self._in_type_map[_in_type]}' + \
f' --transA {self._args.transA} --transB {self._args.transB}' + \
f' --initialization {self._args.initialization}'
command = command + f' -b {str(_b)}' if _b > 0 else command
logger.info(command)
self._commands.append(command)
......@@ -97,9 +123,7 @@ def _process_raw_result(self, cmd_idx, raw_output):
fields = lines[index + 1].strip().split(',')
# Check the number of fields and the format of the first two fields
if len(fields) != 23 or not all(
re.match(r'\d*\.\d*$', item.strip()) or item.strip().isdigit() for item in fields[-2:]
):
if len(fields) != 23:
raise ValueError('Invalid result')
self._result.add_result(
......
......@@ -55,7 +55,7 @@ def test_hipblaslt_gemm_command_generation(self):
self.assertFalse(benchmark._preprocess())
benchmark = benchmark_cls(
self.benchmark_name,
parameters='--shapes 2:4,4:8,8:32 2:4,4:8,8:32:+4 --in_types fp16 fp32 bf16',
parameters='--shapes 2:4,4:8,8:32 2:4,4:8,8:32:+4 --in_types fp16 fp32 bf16 fp8',
)
self.assertTrue(benchmark._preprocess())
self.assertEqual((2 * 2 * 3 + 2 * 2 * 7) * len(benchmark._args.in_types), len(benchmark._commands))
......@@ -63,12 +63,16 @@ def test_hipblaslt_gemm_command_generation(self):
def cmd(t, b, m, n, k):
if b == 0:
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}'
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}' + \
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
f' --initialization {benchmark._args.initialization}'
else:
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}'
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
f' --initialization {benchmark._args.initialization}'
for _t in ['fp16', 'fp32', 'bf16']:
for _t in ['fp16', 'fp32', 'bf16', 'fp8']:
for _m in [2, 4]:
for _n in [4, 8]:
for _k in [8, 16, 32]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment