Commit 1a57f2d6 authored by one's avatar one
Browse files

Enhance DTK platform support and GPU detection

- Added Platform.DTK to the microbenchmark framework.
- Introduced new DTK hipblaslt benchmark class and corresponding tests.
- Updated Dockerfile to include hipblaslt-bench and its permissions.
- Registered DTK benchmarks in the benchmark registry for various performance tests.
- Enhanced GPU detection logic to recognize HYGON GPUs.

This update improves the benchmarking capabilities for DTK, ensuring compatibility and performance testing across platforms.
parent c4f39919
...@@ -129,7 +129,9 @@ RUN cd /tmp && \ ...@@ -129,7 +129,9 @@ RUN cd /tmp && \
# Add rocblas-bench to path # Add rocblas-bench to path
RUN ln -s ${ROCM_PATH}/lib/rocblas/benchmark_tool/rocblas-bench ${ROCM_PATH}/bin/ && \ RUN ln -s ${ROCM_PATH}/lib/rocblas/benchmark_tool/rocblas-bench ${ROCM_PATH}/bin/ && \
chmod +x ${ROCM_PATH}/bin/rocblas-bench chmod +x ${ROCM_PATH}/bin/rocblas-bench && \
ln -s ${ROCM_PATH}/lib/hipblaslt/benchmark_tool/hipblaslt-bench ${ROCM_PATH}/bin/ && \
chmod +x ${ROCM_PATH}/bin/hipblaslt-bench
ENV PATH="${MPI_HOME}/bin:${UCX_HOME}/bin:/opt/superbench/bin:/usr/local/bin/${PATH:+:${PATH}}" \ ENV PATH="${MPI_HOME}/bin:${UCX_HOME}/bin:/opt/superbench/bin:/usr/local/bin/${PATH:+:${PATH}}" \
LD_LIBRARY_PATH="${MPI_HOME}/lib:${UCX_HOME}/lib:/usr/lib/x86_64-linux-gnu/:/usr/local/lib/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" \ LD_LIBRARY_PATH="${MPI_HOME}/lib:${UCX_HOME}/lib:/usr/lib/x86_64-linux-gnu/:/usr/local/lib/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" \
......
...@@ -24,6 +24,7 @@ class Platform(Enum): ...@@ -24,6 +24,7 @@ class Platform(Enum):
CPU = 'CPU' CPU = 'CPU'
CUDA = 'CUDA' CUDA = 'CUDA'
ROCM = 'ROCm' ROCM = 'ROCm'
DTK = 'DTK'
DIRECTX = 'DirectX' DIRECTX = 'DirectX'
...@@ -91,7 +92,7 @@ def __init__(self, name, platform, parameters='', framework=Framework.NONE): ...@@ -91,7 +92,7 @@ def __init__(self, name, platform, parameters='', framework=Framework.NONE):
Args: Args:
name (str): name of benchmark in config file. name (str): name of benchmark in config file.
platform (Platform): Platform types like CUDA, ROCM. platform (Platform): Platform types like CUDA, ROCM, DTK.
parameters (str): predefined parameters of benchmark. parameters (str): predefined parameters of benchmark.
framework (Framework): Framework types like ONNXRUNTIME, PYTORCH. framework (Framework): Framework types like ONNXRUNTIME, PYTORCH.
""" """
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import BlasLtBaseBenchmark from superbench.benchmarks.micro_benchmarks.blaslt_function_base import BlasLtBaseBenchmark
from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark
from superbench.benchmarks.micro_benchmarks.hipblaslt_function import HipBlasLtBenchmark from superbench.benchmarks.micro_benchmarks.rocm_hipblaslt_function import RocmHipBlasLtBenchmark
from superbench.benchmarks.micro_benchmarks.dtk_hipblaslt_function import DtkHipBlasLtBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
...@@ -54,7 +55,8 @@ ...@@ -54,7 +55,8 @@
'CudnnBenchmark', 'CudnnBenchmark',
'DiskBenchmark', 'DiskBenchmark',
'DistInference', 'DistInference',
'HipBlasLtBenchmark', 'RocmHipBlasLtBenchmark',
'DtkHipBlasLtBenchmark',
'GPCNetBenchmark', 'GPCNetBenchmark',
'GemmFlopsBenchmark', 'GemmFlopsBenchmark',
'GpuBurnBenchmark', 'GpuBurnBenchmark',
......
...@@ -242,3 +242,4 @@ def _process_raw_result(self, cmd_idx, raw_output): # noqa: C901 ...@@ -242,3 +242,4 @@ def _process_raw_result(self, cmd_idx, raw_output): # noqa: C901
BenchmarkRegistry.register_benchmark('nccl-bw', CudaNcclBwBenchmark, platform=Platform.CUDA) BenchmarkRegistry.register_benchmark('nccl-bw', CudaNcclBwBenchmark, platform=Platform.CUDA)
BenchmarkRegistry.register_benchmark('rccl-bw', CudaNcclBwBenchmark, platform=Platform.ROCM) BenchmarkRegistry.register_benchmark('rccl-bw', CudaNcclBwBenchmark, platform=Platform.ROCM)
BenchmarkRegistry.register_benchmark('rccl-bw', CudaNcclBwBenchmark, platform=Platform.DTK)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the hipBlasLt GEMM benchmark."""
import os
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark
class DtkHipBlasLtBenchmark(BlasLtBaseBenchmark):
    """The DTK hipBlasLt GEMM benchmark class.

    Runs the `hipblaslt-bench` binary once per configured GEMM shape and
    input precision, and reports the measured throughput in TFLOPS.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'hipblaslt-bench'
        self._in_types = ['fp32', 'fp16', 'bf16', 'fp8']
        # hipblaslt-bench type flags for each supported precision; every
        # precision accumulates in fp32 (--compute_type f32_r).
        self._in_type_map = {
            'fp16': '--a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --compute_type f32_r',
            'fp32': '--a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --compute_type f32_r',
            'bf16': '--a_type bf16_r --b_type bf16_r --c_type bf16_r --d_type bf16_r --compute_type f32_r',
            'fp8': '--a_type f8_r --b_type f8_r --c_type f8_r --d_type f8_r --compute_type f32_r',
        }

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--in_types',
            type=str,
            nargs='+',
            default=['fp16'],
            required=False,
            help='List of input data types, support {}.'.format(' '.join(self._in_types)),
        )
        self._parser.add_argument(
            '--initialization',
            type=str,
            default='rand_int',
            choices=['trig_float', 'rand_int', 'hpl'],
            required=False,
            help='Initialize matrix data.',
        )
        self._parser.add_argument(
            '--transA',
            type=str,
            default='N',
            choices=['N', 'T', 'C'],
            required=False,
            help='Transpose matrix A.',
        )
        self._parser.add_argument(
            '--transB',
            type=str,
            default='N',
            choices=['N', 'T', 'C'],
            required=False,
            help='Transpose matrix B.',
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Builds one `hipblaslt-bench` command line per (shape, precision)
        combination produced by the base class.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

        self._commands = []
        # Record the precision used by each generated command so that
        # _process_raw_result can label its metric by cmd_idx.
        self._precision_in_commands = []
        for (_m, _n, _k, _b, _in_type) in self._shapes_to_run:
            command = f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -j {self._args.num_warmup}' + \
                f' -i {self._args.num_steps} {self._in_type_map[_in_type]}' + \
                f' --transA {self._args.transA} --transB {self._args.transB}' + \
                f' --initialization {self._args.initialization}'
            # Only pass a batch count when it is meaningful (> 0).
            if _b > 0:
                command += f' -b {_b}'
            logger.info(command)
            self._commands.append(command)
            self._precision_in_commands.append(_in_type)

        return True

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

        try:
            lines = raw_output.splitlines()
            index = None
            # Locate the CSV header line, identified by 'hipblaslt-Gflops';
            # the data row immediately follows it.
            for i, line in enumerate(lines):
                if 'hipblaslt-Gflops' in line:
                    index = i
                    break
            if index is None:
                raise ValueError('Line with "hipblaslt-Gflops" not found in the log.')
            # Header fields may carry a '[<n>]:' prefix; strip it off.
            header = [field.strip().lstrip('[]0123456789:') for field in lines[index].strip().split(',')]
            fields = [field.strip() for field in lines[index + 1].strip().split(',')]
            if len(fields) != len(header):
                raise ValueError('Invalid result')
            batch_count_index = header.index('batch_count')
            m_index = header.index('m')
            n_index = header.index('n')
            k_index = header.index('k')
            gflops_index = header.index('hipblaslt-Gflops')
            # Metric name: '<precision>_<batch_count>_<m>_<n>_<k>_flops';
            # value converted from Gflops to Tflops.
            self._result.add_result(
                f'{self._precision_in_commands[cmd_idx]}_{fields[batch_count_index]}_'
                f'{"_".join(fields[m_index:k_index + 1])}_flops',
                float(fields[gflops_index]) / 1000
            )
        except Exception as e:
            # Narrowed from BaseException so that KeyboardInterrupt/SystemExit
            # are not swallowed and reported as a parsing failure.
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False

        return True
BenchmarkRegistry.register_benchmark('hipblaslt-gemm', DtkHipBlasLtBenchmark, platform=Platform.DTK)
...@@ -169,3 +169,4 @@ def _process_raw_result(self, cmd_idx, raw_output): ...@@ -169,3 +169,4 @@ def _process_raw_result(self, cmd_idx, raw_output):
BenchmarkRegistry.register_benchmark('gemm-flops', RocmGemmFlopsBenchmark, platform=Platform.ROCM) BenchmarkRegistry.register_benchmark('gemm-flops', RocmGemmFlopsBenchmark, platform=Platform.ROCM)
BenchmarkRegistry.register_benchmark('gemm-flops', RocmGemmFlopsBenchmark, platform=Platform.DTK)
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark
class HipBlasLtBenchmark(BlasLtBaseBenchmark): class RocmHipBlasLtBenchmark(BlasLtBaseBenchmark):
"""The hipBlasLt GEMM benchmark class.""" """The hipBlasLt GEMM benchmark class."""
def __init__(self, name, parameters=''): def __init__(self, name, parameters=''):
"""Constructor. """Constructor.
...@@ -142,4 +142,4 @@ def _process_raw_result(self, cmd_idx, raw_output): ...@@ -142,4 +142,4 @@ def _process_raw_result(self, cmd_idx, raw_output):
return True return True
BenchmarkRegistry.register_benchmark('hipblaslt-gemm', HipBlasLtBenchmark, platform=Platform.ROCM) BenchmarkRegistry.register_benchmark('hipblaslt-gemm', RocmHipBlasLtBenchmark, platform=Platform.ROCM)
...@@ -91,3 +91,4 @@ def _process_raw_result(self, cmd_idx, raw_output): ...@@ -91,3 +91,4 @@ def _process_raw_result(self, cmd_idx, raw_output):
BenchmarkRegistry.register_benchmark('mem-bw', RocmMemBwBenchmark, platform=Platform.ROCM) BenchmarkRegistry.register_benchmark('mem-bw', RocmMemBwBenchmark, platform=Platform.ROCM)
BenchmarkRegistry.register_benchmark('mem-bw', RocmMemBwBenchmark, platform=Platform.DTK)
...@@ -30,7 +30,7 @@ def register_benchmark(cls, name, class_def, parameters='', platform=None): ...@@ -30,7 +30,7 @@ def register_benchmark(cls, name, class_def, parameters='', platform=None):
name (str): internal name of benchmark. name (str): internal name of benchmark.
class_def (Benchmark): class object of benchmark. class_def (Benchmark): class object of benchmark.
parameters (str): predefined parameters of benchmark. parameters (str): predefined parameters of benchmark.
platform (Platform): Platform types like CUDA, ROCM. platform (Platform): Platform types like CUDA, ROCM, DTK.
""" """
if not name or not isinstance(name, str): if not name or not isinstance(name, str):
logger.log_and_raise( logger.log_and_raise(
...@@ -142,7 +142,7 @@ def create_benchmark_context(cls, name, platform=Platform.CPU, parameters='', fr ...@@ -142,7 +142,7 @@ def create_benchmark_context(cls, name, platform=Platform.CPU, parameters='', fr
Args: Args:
name (str): name of benchmark in config file. name (str): name of benchmark in config file.
platform (Platform): Platform types like Platform.CPU, Platform.CUDA, Platform.ROCM. platform (Platform): Platform types like Platform.CPU, Platform.CUDA, Platform.ROCM, Platform.DTK.
parameters (str): predefined parameters of benchmark. parameters (str): predefined parameters of benchmark.
framework (Framework): Framework types like Framework.PYTORCH, Framework.ONNXRUNTIME. framework (Framework): Framework types like Framework.PYTORCH, Framework.ONNXRUNTIME.
......
...@@ -28,6 +28,8 @@ def get_vendor(self): ...@@ -28,6 +28,8 @@ def get_vendor(self):
if Path('/dev/kfd').is_char_device() and Path('/dev/dri').is_dir(): if Path('/dev/kfd').is_char_device() and Path('/dev/dri').is_dir():
if not list(Path('/dev/dri').glob('renderD*')): if not list(Path('/dev/dri').glob('renderD*')):
logger.warning('Cannot find AMD GPU device.') logger.warning('Cannot find AMD GPU device.')
if Path('/usr/local/hyhal').exists():
return 'hygon'
return 'amd' return 'amd'
if list(Path(r'C:\Windows\System32').glob('*DriverStore/FileRepository/nv*.inf_amd64_*/nvapi64.dll')): if list(Path(r'C:\Windows\System32').glob('*DriverStore/FileRepository/nv*.inf_amd64_*/nvapi64.dll')):
return 'nvidia-graphics' return 'nvidia-graphics'
......
...@@ -87,6 +87,8 @@ def __get_platform(self): ...@@ -87,6 +87,8 @@ def __get_platform(self):
return Platform.CUDA return Platform.CUDA
elif gpu.vendor == 'amd': elif gpu.vendor == 'amd':
return Platform.ROCM return Platform.ROCM
elif gpu.vendor == 'hygon':
return Platform.DTK
elif gpu.vendor == 'amd-graphics' or gpu.vendor == 'nvidia-graphics': elif gpu.vendor == 'amd-graphics' or gpu.vendor == 'nvidia-graphics':
return Platform.DIRECTX return Platform.DIRECTX
except Exception as e: except Exception as e:
......
...@@ -37,6 +37,14 @@ ...@@ -37,6 +37,14 @@
- /dev/kfd - /dev/kfd
- /dev/dri - /dev/dri
register: amd_dev register: amd_dev
- name: Checking HYGON GPU Environment
stat:
path: '{{ item }}'
with_items:
- /dev/kfd
- /dev/dri
- /usr/local/hyhal
register: hygon_dev
- name: Set GPU Facts - name: Set GPU Facts
set_fact: set_fact:
nvidia_gpu_exist: >- nvidia_gpu_exist: >-
...@@ -45,11 +53,16 @@ ...@@ -45,11 +53,16 @@
amd_gpu_exist: >- amd_gpu_exist: >-
{{ amd_dev.results[0].stat.ischr is defined and amd_dev.results[0].stat.ischr and {{ amd_dev.results[0].stat.ischr is defined and amd_dev.results[0].stat.ischr and
amd_dev.results[1].stat.isdir is defined and amd_dev.results[1].stat.isdir }} amd_dev.results[1].stat.isdir is defined and amd_dev.results[1].stat.isdir }}
hygon_gpu_exist: >-
{{ (hygon_dev.results[0].stat.ischr is defined and hygon_dev.results[0].stat.ischr and
hygon_dev.results[1].stat.isdir is defined and hygon_dev.results[1].stat.isdir) and
hygon_dev.results[2].stat.exists is defined and hygon_dev.results[2].stat.exists }}
- name: Print GPU Checking Result - name: Print GPU Checking Result
debug: debug:
msg: msg:
- "NVIDIA GPU {{ 'detected' if nvidia_gpu_exist else 'not operational, pls confirm nvidia_uvm kernel module is loaded and /dev/nvidia-uvm exists' }}" - "NVIDIA GPU {{ 'detected' if nvidia_gpu_exist else 'not operational, pls confirm nvidia_uvm kernel module is loaded and /dev/nvidia-uvm exists' }}"
- "AMD GPU {{ 'detected' if amd_gpu_exist else 'not operational, pls confirm amdgpu kernel module is loaded' }}" - "AMD GPU {{ 'detected' if amd_gpu_exist and not hygon_gpu_exist else 'not operational, pls confirm amdgpu kernel module is loaded' }}"
- "HYGON GPU {{ 'detected' if hygon_gpu_exist else 'not operational, pls confirm hygon gpu kernel module and driver are loaded' }}"
- name: Remote Deployment - name: Remote Deployment
hosts: all hosts: all
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for DTK hipblaslt-bench benchmark."""
import unittest
from types import SimpleNamespace
from tests.helper.testcase import BenchmarkTestCase
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
from superbench.benchmarks.result import BenchmarkResult
class DtkHipblasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
    """Class for DTK hipblaslt-bench benchmark test cases."""
    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in the class."""
        super().setUpClass()
        cls.benchmark_name = 'hipblaslt-gemm'
        # Mock the environment and a fake 'bin/hipblaslt-bench' so that
        # _preprocess can resolve the binary path without a real DTK install.
        cls.createMockEnvs(cls)
        cls.createMockFiles(cls, ['bin/hipblaslt-bench'])
    def get_benchmark(self):
        """Get benchmark.

        Returns a DTK hipblaslt-gemm benchmark instance with default parameters.
        """
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.DTK)
        return benchmark_cls(self.benchmark_name, parameters='')
    def test_hipblaslt_gemm_cls(self):
        """Test DTK hipblaslt-bench benchmark class."""
        # NOTE(review): 'hipblaslt-gemm' is also registered for Platform.ROCM
        # elsewhere; confirm the registry state in this test run really yields
        # None for every non-DTK platform as asserted below.
        for platform in Platform:
            (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
            if platform is Platform.DTK:
                self.assertIsNotNone(benchmark_cls)
            else:
                self.assertIsNone(benchmark_cls)
    def test_hipblaslt_gemm_command_generation(self):
        """Test DTK hipblaslt-bench benchmark command generation."""
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.DTK)
        # Invalid batch spec (negative step) -> preprocess must fail.
        benchmark = benchmark_cls(
            self.benchmark_name,
            parameters='--batch 4:2:-1 --shapes 2,4,8 --in_types fp16 fp32 fp64 int8',
        )
        self.assertFalse(benchmark._preprocess())
        # Unsupported in_types (fp64, int8) -> preprocess must fail.
        benchmark = benchmark_cls(
            self.benchmark_name,
            parameters=' --shapes 2,4,8 --in_types fp16 fp32 fp64 int8',
        )
        self.assertFalse(benchmark._preprocess())
        # Incomplete shape spec (only two dimensions) -> preprocess must fail.
        benchmark = benchmark_cls(
            self.benchmark_name,
            parameters=' --shapes 2:4,4:8 --in_types fp16 fp32',
        )
        self.assertFalse(benchmark._preprocess())
        # Valid shape ranges (doubling and '+4' stepped k) -> preprocess succeeds.
        benchmark = benchmark_cls(
            self.benchmark_name,
            parameters='--shapes 2:4,4:8,8:32 2:4,4:8,8:32:+4 --in_types fp16 fp32 bf16 fp8',
        )
        self.assertTrue(benchmark._preprocess())
        # 2 m-values * 2 n-values * (3 doubled k + 7 stepped k) per precision.
        self.assertEqual((2 * 2 * 3 + 2 * 2 * 7) * len(benchmark._args.in_types), len(benchmark._commands))
        def cmd(t, b, m, n, k):
            # Rebuild the expected command line (b == 0 means no batch flag);
            # accesses the name-mangled private __bin_path of the benchmark.
            if b == 0:
                return f'{benchmark._DtkHipBlasLtBenchmark__bin_path} ' + \
                    f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}' + \
                    f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
                    f' --initialization {benchmark._args.initialization}'
            else:
                return f'{benchmark._DtkHipBlasLtBenchmark__bin_path} ' + \
                    f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \
                    f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
                    f' --initialization {benchmark._args.initialization}'
        for _t in ['fp16', 'fp32', 'bf16', 'fp8']:
            for _m in [2, 4]:
                for _n in [4, 8]:
                    # k from the doubling range 8:32.
                    for _k in [8, 16, 32]:
                        self.assertIn(cmd(_t, 0, _m, _n, _k), benchmark._commands)
                    # k from the additive range 8:32:+4.
                    for _k in [8, 12, 16, 20, 24, 28, 32]:
                        self.assertIn(cmd(_t, 0, _m, _n, _k), benchmark._commands)
    def test_hipblaslt_gemm_result_parsing(self):
        """Test DTK hipblaslt-bench benchmark result parsing."""
        benchmark = self.get_benchmark()
        self.assertTrue(benchmark._preprocess())
        # Replace parsed args with a minimal namespace for parsing only.
        benchmark._args = SimpleNamespace(shapes=['4096,4096,4096'], in_types=['fp32'], log_raw_data=False)
        benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)
        example_raw_output = """
hipBLASLt version: 1000
hipBLASLt git version: 4bd05bb5-dirty
Query device success: there are 1 devices
-------------------------------------------------------------------------------
Device ID 0 : BW150 gfx936:sramecc+:xnack-
with 68.7 GB memory, max. SCLK 1400 MHz, max. MCLK 1800 MHz, compute capability 9.3
maxGridDimX 2147483647, sharedMemPerBlock 65.5 KB, maxThreadsPerBlock 1024, warpSize 64
-------------------------------------------------------------------------------
Is supported 1 / Total solutions: 1
[0]:transA,transB,grouped_gemm,batch_count,m,n,k,alpha,lda,stride_a,beta,ldb,stride_b,ldc,stride_c,ldd,stride_d,a_type,b_type,c_type,d_type,compute_type,scaleA,scaleB,scaleC,scaleD,amaxD,activation_type,bias_vector,bias_type,hipblaslt-Gflops,us
N,N,0,1,4096,4096,4096,1,4096,16777216,0,4096,16777216,4096,16777216,4096,16777216,f32_r,f32_r,f32_r,f32_r,f32_r,0,0,0,0,0,none,0,non-supported type,1595.18,86159.1
"""
        self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
        self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
        self.assertEqual(2, len(benchmark.result))
        # NOTE(review): the metric prefix comes from _precision_in_commands,
        # which was populated during _preprocess() with the default
        # in_types=['fp16'] — confirm the expected 'fp32_' key actually matches.
        # 1595.18 Gflops / 1000 == 1.59518 Tflops.
        self.assertEqual(1.59518, benchmark.result['fp32_1_4096_4096_4096_flops'][0])
        # Unparseable output must be rejected.
        self.assertFalse(benchmark._process_raw_result(1, 'HipBLAS API failed'))
...@@ -62,12 +62,12 @@ def test_hipblaslt_gemm_command_generation(self): ...@@ -62,12 +62,12 @@ def test_hipblaslt_gemm_command_generation(self):
def cmd(t, b, m, n, k): def cmd(t, b, m, n, k):
if b == 0: if b == 0:
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \ return f'{benchmark._RocmHipBlasLtBenchmark__bin_path} ' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}' + \ f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]}' + \
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \ f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
f' --initialization {benchmark._args.initialization}' f' --initialization {benchmark._args.initialization}'
else: else:
return f'{benchmark._HipBlasLtBenchmark__bin_path} ' + \ return f'{benchmark._RocmHipBlasLtBenchmark__bin_path} ' + \
f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \ f'-m {m} -n {n} -k {k} -j 20 -i 50 {benchmark._in_type_map[t]} -b {b}' + \
f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \ f' --transA {benchmark._args.transA} --transB {benchmark._args.transB}' + \
f' --initialization {benchmark._args.initialization}' f' --initialization {benchmark._args.initialization}'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment