Unverified Commit dbeba805 authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Benchmark - Support batch/shape range in cublaslt gemm (#494)

Support batch and shape range with multiplication factors in cublaslt
gemm benchmark.
parent 655bd0aa
...@@ -66,9 +66,9 @@ Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/ ...@@ -66,9 +66,9 @@ Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/
#### Metrics #### Metrics
| Name | Unit | Description | | Name | Unit | Description |
|------------------------------------------------|----------------|---------------------------------| |----------------------------------------------------------|----------------|---------------------------------|
| cublaslt-gemm/${dtype}\_${m}\_${n}\_${k}_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. | | cublaslt-gemm/${dtype}\_${batch}\_${m}\_${n}\_${k}_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
### `cublas-function` ### `cublas-function`
...@@ -195,13 +195,13 @@ performed by [University of Virginia STREAM benchmark](https://www.cs.virginia.e ...@@ -195,13 +195,13 @@ performed by [University of Virginia STREAM benchmark](https://www.cs.virginia.e
#### Metrics #### Metrics
| Name | Unit | Description | | Name | Unit | Description |
|----------------------------------------------------------|------------------|---------------------------------------------------------------------| |----------------------------------------------------------|------------------|----------------------------------------------------------------|
| cpu-stream/threads | | Number of threads used for the test. Determined by core count. | | cpu-stream/threads | | Number of threads used for the test. Determined by core count. |
| cpu-stream/['copy', 'scale', 'add', 'triad']\_throughput | bandwidth (MB/s) | Memory throughput of designated kernel operation. | | cpu-stream/['copy', 'scale', 'add', 'triad']\_throughput | bandwidth (MB/s) | Memory throughput of designated kernel operation. |
| cpu-stream/['copy', 'scale', 'add', 'triad']\_time_avg | time (s) | Average elapsed times over all iterations. | | cpu-stream/['copy', 'scale', 'add', 'triad']\_time_avg | time (s) | Average elapsed times over all iterations. |
| cpu-stream/['copy', 'scale', 'add', 'triad']\_time_min | time (s) | Minimum elapsed times over all iterations. | | cpu-stream/['copy', 'scale', 'add', 'triad']\_time_min | time (s) | Minimum elapsed times over all iterations. |
| cpu-stream/['copy', 'scale', 'add', 'triad']\_time_max | time (s) | Maximum elapsed times over all iterations. | | cpu-stream/['copy', 'scale', 'add', 'triad']\_time_max | time (s) | Maximum elapsed times over all iterations. |
## Communication Benchmarks ## Communication Benchmarks
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Module of the cuBLASLt GEMM benchmark.""" """Module of the cuBLASLt GEMM benchmark."""
import os import os
import itertools
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
...@@ -24,6 +25,37 @@ def __init__(self, name, parameters=''): ...@@ -24,6 +25,37 @@ def __init__(self, name, parameters=''):
self._bin_name = 'cublaslt_gemm' self._bin_name = 'cublaslt_gemm'
self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2'] self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2']
def mrange(self, start, stop=-1, multiplication_factor=2):
    """Geometric range generator: start, start*f, start*f^2, ... up to stop.

    Always yields ``start`` at least once. Iteration ends as soon as the next
    value would exceed ``stop``, the value collapses to zero, or the factor
    is smaller than 2 (which would never grow the value).

    Args:
        start (int): First number to yield.
        stop (int, optional): Inclusive upper bound. Defaults to -1.
        multiplication_factor (int, optional): Growth factor per step. Defaults to 2.

    Yields:
        int: Next number in the geometric progression.
    """
    current = start
    yield current
    # A factor below 2 cannot grow the value, so only the first element is produced.
    while multiplication_factor >= 2:
        current *= multiplication_factor
        # Stop without yielding once the bound is passed or the value degenerates to 0.
        if current > stop or current == 0:
            break
        yield current
def validate_mrange(self, string):
    """Validate an mrange string of the form start[:stop[:multiplication_factor]].

    Args:
        string (str): Candidate mrange specification.

    Returns:
        bool: True when the string has at most three colon-separated
            components and every component is a non-negative integer.
    """
    parts = string.split(':')
    # At most start:stop:factor, each a plain decimal number.
    return len(parts) <= 3 and all(part.isdigit() for part in parts)
def add_parser_arguments(self): def add_parser_arguments(self):
"""Add the specified arguments.""" """Add the specified arguments."""
super().add_parser_arguments() super().add_parser_arguments()
...@@ -33,14 +65,17 @@ def add_parser_arguments(self): ...@@ -33,14 +65,17 @@ def add_parser_arguments(self):
type=str, type=str,
nargs='+', nargs='+',
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]], default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
help='Shapes in m,n,k format.', help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
) )
self._parser.add_argument( self._parser.add_argument(
'--batch', '--batch',
type=int, type=str,
default=0, default='0',
required=False, required=False,
help='Batch size for strided batch GEMM, set 0 to disable.', help=(
'Batch size for strided batch GEMM, set 0 to disable.'
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
),
) )
self._parser.add_argument( self._parser.add_argument(
'--num_warmup', '--num_warmup',
...@@ -57,11 +92,12 @@ def add_parser_arguments(self): ...@@ -57,11 +92,12 @@ def add_parser_arguments(self):
help='Number of steps to measure.', help='Number of steps to measure.',
) )
self._parser.add_argument( self._parser.add_argument(
'--in_type', '--in_types',
type=str, type=str,
default='fp8e4m3', nargs='+',
default=['fp8e4m3'],
required=False, required=False,
help='Input data type, supports {}.'.format(' '.join(self._in_types)), help='List of input data types, support {}.'.format(' '.join(self._in_types)),
) )
def _preprocess(self): def _preprocess(self):
...@@ -75,20 +111,28 @@ def _preprocess(self): ...@@ -75,20 +111,28 @@ def _preprocess(self):
self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name) self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
if self._args.in_type not in self._in_types: if not self.validate_mrange(self._args.batch):
logger.error(f'Invalid input type {self._args.in_type}.') logger.error(f'Invalid batch size {self._args.batch}.')
return False return False
self._commands = [] self._commands = []
for shape in self._args.shapes: for _in_type in self._args.in_types:
shape_list = shape.replace(',', ' ').split() if _in_type not in self._in_types:
if len(shape_list) != 3 or not all(x.isdigit() for x in shape_list): logger.error(f'Invalid input type {_in_type}.')
logger.error(f'Invalid shape {shape}.')
return False return False
self._commands.append( for _b in self.mrange(*map(int, self._args.batch.split(':'))):
f'{self.__bin_path} -m {shape_list[0]} -n {shape_list[1]} -k {shape_list[2]} ' for shape in self._args.shapes:
f'-b {self._args.batch} -w {self._args.num_warmup} -i {self._args.num_steps} -t {self._args.in_type}' shape_list = shape.replace(',', ' ').split()
) if len(shape_list) != 3 or not all(self.validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape {shape}.')
return False
for _m, _n, _k in itertools.product(
*map(lambda shape: self.mrange(*map(int, shape.split(':'))), shape_list)
):
self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
)
return True return True
...@@ -110,7 +154,9 @@ def _process_raw_result(self, cmd_idx, raw_output): ...@@ -110,7 +154,9 @@ def _process_raw_result(self, cmd_idx, raw_output):
fields = raw_output.strip().split() fields = raw_output.strip().split()
if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]): if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
raise ValueError('Invalid result.') raise ValueError('Invalid result.')
self._result.add_result(f'{self._args.in_type}_{"_".join(fields[:3])}_flops', float(fields[-1])) self._result.add_result(
f'{self._commands[cmd_idx].split()[-1]}_{fields[3]}_{"_".join(fields[:3])}_flops', float(fields[-1])
)
except BaseException as e: except BaseException as e:
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE) self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
logger.error( logger.error(
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Tests for cublaslt-gemm benchmark.""" """Tests for cublaslt-gemm benchmark."""
import unittest import unittest
from types import SimpleNamespace from types import GeneratorType, SimpleNamespace
from tests.helper.testcase import BenchmarkTestCase from tests.helper.testcase import BenchmarkTestCase
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
...@@ -19,7 +19,12 @@ def setUpClass(cls): ...@@ -19,7 +19,12 @@ def setUpClass(cls):
super().setUpClass() super().setUpClass()
cls.benchmark_name = 'cublaslt-gemm' cls.benchmark_name = 'cublaslt-gemm'
cls.createMockEnvs(cls) cls.createMockEnvs(cls)
cls.createMockFiles(cls, ['bin/cublaslt_fp8_gemm']) cls.createMockFiles(cls, ['bin/cublaslt_gemm'])
def get_benchmark(self):
    """Instantiate the cublaslt-gemm benchmark with default parameters.

    Returns:
        CublasLtBenchmark: A fresh benchmark instance for the CUDA platform.
    """
    selected = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
    benchmark_class = selected[0]
    return benchmark_class(self.benchmark_name, parameters='')
def test_cublaslt_gemm_cls(self): def test_cublaslt_gemm_cls(self):
"""Test cublaslt-gemm benchmark class.""" """Test cublaslt-gemm benchmark class."""
...@@ -30,11 +35,56 @@ def test_cublaslt_gemm_cls(self): ...@@ -30,11 +35,56 @@ def test_cublaslt_gemm_cls(self):
else: else:
self.assertIsNone(benchmark_cls) self.assertIsNone(benchmark_cls)
def test_mrange(self):
    """Test mrange generation."""
    benchmark = self.get_benchmark()
    # mrange must be lazy — a generator, not a materialized list.
    self.assertIsInstance(benchmark.mrange(1), GeneratorType)
    # (args, expected sequence) pairs covering growth, bounds, and degenerate factors.
    cases = [
        ((4, 32, 2), [4, 8, 16, 32]),
        ((2, 31, 2), [2, 4, 8, 16]),
        ((2, 8), [2, 4, 8]),
        ((2, 0, 2), [2]),
        ((2, ), [2]),
        ((2, 4, 1), [2]),
        ((2, 4, 0), [2]),
        ((0, 0), [0]),
        ((0, ), [0]),
    ]
    for args, expected in cases:
        self.assertListEqual(expected, list(benchmark.mrange(*args)))
def test_validate_mrange(self):
    """Test mrange validation."""
    benchmark = self.get_benchmark()
    # Well-formed specs: 1 to 3 colon-separated integers.
    for good in ('2:32:2', '4:32', '8'):
        self.assertTrue(benchmark.validate_mrange(good))
    # Malformed specs: too many components, or non-integer component.
    for bad in ('2:32:2:4', '2.5:32'):
        self.assertFalse(benchmark.validate_mrange(bad))
def test_cublaslt_gemm_command_generation(self):
    """Test cublaslt-gemm benchmark command generation."""
    (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
    benchmark = benchmark_cls(
        self.benchmark_name,
        parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64',
    )
    self.assertTrue(benchmark._preprocess())
    # 4 batch sizes x (2*2*3 shapes from the ranged spec + 2 from the fixed spec) x 3 dtypes.
    self.assertEqual(4 * (2 * 2 * 3 + 2) * 3, len(benchmark._commands))

    def expected_cmd(dtype, batch, m, n, k):
        # Mirror the exact command line _preprocess builds per (dtype, batch, shape).
        return f'{benchmark._CublasLtBenchmark__bin_path} -m {m} -n {n} -k {k} -b {batch} -w 20 -i 50 -t {dtype}'

    for dtype in ('fp16', 'fp32', 'fp64'):
        for batch in (2, 4, 8, 16):
            for m in (2, 4):
                for n in (4, 8):
                    for k in (8, 16, 32):
                        self.assertIn(expected_cmd(dtype, batch, m, n, k), benchmark._commands)
            for m in (32, 128):
                self.assertIn(expected_cmd(dtype, batch, m, 128, 128), benchmark._commands)
def test_cublaslt_gemm_result_parsing(self): def test_cublaslt_gemm_result_parsing(self):
"""Test cublaslt-gemm benchmark result parsing.""" """Test cublaslt-gemm benchmark result parsing."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA) benchmark = self.get_benchmark()
benchmark = benchmark_cls(self.benchmark_name, parameters='') self.assertTrue(benchmark._preprocess())
benchmark._args = SimpleNamespace(shapes=['16,16,16', '32,64,128'], in_type='fp8e4m3', log_raw_data=False) benchmark._args = SimpleNamespace(shapes=['16,16,16', '32,64,128'], in_types=['fp8e4m3'], log_raw_data=False)
benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1) benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)
# Positive case - valid raw output # Positive case - valid raw output
...@@ -44,7 +94,7 @@ def test_cublaslt_gemm_result_parsing(self): ...@@ -44,7 +94,7 @@ def test_cublaslt_gemm_result_parsing(self):
self.assertEqual(3, len(benchmark.result)) self.assertEqual(3, len(benchmark.result))
for shape in benchmark._args.shapes: for shape in benchmark._args.shapes:
self.assertEqual(2.222, benchmark.result[f'fp8e4m3_{shape.replace(",", "_")}_flops'][0]) self.assertEqual(2.222, benchmark.result[f'fp8e4m3_0_{shape.replace(",", "_")}_flops'][0])
# Negative case - invalid raw output # Negative case - invalid raw output
self.assertFalse(benchmark._process_raw_result(1, 'cuBLAS API failed')) self.assertFalse(benchmark._process_raw_result(1, 'cuBLAS API failed'))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment