tsoc/superbenchmark · Commit 616e7a5a (Unverified)

Benchmarks - Integrate cublaslt micro-benchmark (#455)

Integrate cublaslt-gemm micro-benchmark #451.

Authored Jan 03, 2023 by Yifan Xiong; committed via GitHub on Jan 03, 2023.
Parent: 75573f59
Showing 4 changed files, with 192 additions and 2 deletions (+192 −2):
- docs/user-tutorial/benchmarks/micro-benchmarks.md (+12 −0)
- superbench/benchmarks/micro_benchmarks/__init__.py (+4 −2)
- superbench/benchmarks/micro_benchmarks/cublaslt_function.py (+126 −0)
- tests/benchmarks/micro_benchmarks/test_cublaslt_function.py (+50 −0)
docs/user-tutorial/benchmarks/micro-benchmarks.md

```diff
@@ -58,6 +58,18 @@ Large scale matmul operation using `torch.matmul` with one GPU.
 |--------------------------------|-----------|--------------------------------|
 | pytorch-matmul/nosharding_time | time (ms) | Time of pure matmul operation. |
 
+### `cublaslt-gemm`
+
+#### Introduction
+
+Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/cublas/#cublasltmatmul).
+
+#### Metrics
+
+| Name                             | Unit           | Description                     |
+|----------------------------------|----------------|---------------------------------|
+| cublaslt-gemm/dtype_m_n_k_flops  | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
+
 ### `cublas-function`
 
 #### Introduction
```
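For context on how a reader would exercise the new entry (not part of this commit): SuperBench micro-benchmarks are typically launched through the registry, as in the example scripts elsewhere in this repo. A minimal sketch, assuming the `create_benchmark_context`/`launch_benchmark` API those scripts use, with illustrative parameter values:

```python
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

# Illustrative parameters; defaults come from cublaslt_function.py below.
context = BenchmarkRegistry.create_benchmark_context(
    'cublaslt-gemm', platform=Platform.CUDA, parameters='--shapes 4096,4096,4096 --in_type fp8e4m3'
)

benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
    logger.info(
        'benchmark: {}, return code: {}, result: {}'.format(benchmark.name, benchmark.return_code, benchmark.result)
    )
```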
superbench/benchmarks/micro_benchmarks/__init__.py

```diff
@@ -9,6 +9,7 @@
 from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
 from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
+from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark
 from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
 from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
 from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
@@ -30,17 +31,18 @@
 __all__ = [
     'ComputationCommunicationOverlap',
+    'CpuMemBwLatencyBenchmark',
     'CublasBenchmark',
+    'CublasLtBenchmark',
     'CudaGemmFlopsBenchmark',
     'CudaMemBwBenchmark',
     'CudaNcclBwBenchmark',
     'CudnnBenchmark',
     'DiskBenchmark',
-    'CpuMemBwLatencyBenchmark',
     'GPCNetBenchmark',
     'GemmFlopsBenchmark',
-    'GpuCopyBwBenchmark',
     'GpuBurnBenchmark',
+    'GpuCopyBwBenchmark',
     'IBBenchmark',
     'IBLoopbackBenchmark',
     'KernelLaunch',
```
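Besides adding `CublasLtBenchmark`, this hunk re-alphabetizes `CpuMemBwLatencyBenchmark` and `GpuCopyBwBenchmark` within `__all__`. With the re-export in place, the new class resolves from the package root like its siblings; a trivial sketch:

```python
# Sketch only: the re-export added above makes this import path valid.
from superbench.benchmarks.micro_benchmarks import CublasLtBenchmark

print(CublasLtBenchmark.__doc__)    # 'The cuBLASLt GEMM benchmark class.'
```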
superbench/benchmarks/micro_benchmarks/cublaslt_function.py (new file, mode 0 → 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the cuBLASLt GEMM benchmark."""

import os

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


class CublasLtBenchmark(MicroBenchmarkWithInvoke):
    """The cuBLASLt GEMM benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'cublaslt_fp8_gemm'
        self._in_types = ['fp16', 'fp8e4m3', 'fp8e5m2']

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--shapes',
            type=str,
            nargs='+',
            default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
            help='Shapes in m,n,k format.',
        )
        self._parser.add_argument(
            '--batch',
            type=int,
            default=0,
            required=False,
            help='Batch size for strided batch GEMM, set 0 to disable.',
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=20,
            required=False,
            help='Number of warm up steps.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=50,
            required=False,
            help='Number of steps to measure.',
        )
        self._parser.add_argument(
            '--in_type',
            type=str,
            default='fp8e4m3',
            required=False,
            help='Input data type, supports {}.'.format(' '.join(self._in_types)),
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

        if self._args.in_type not in self._in_types:
            logger.error(f'Invalid input type {self._args.in_type}.')
            return False

        self._commands = []
        for shape in self._args.shapes:
            shape_list = shape.replace(',', ' ').split()
            if len(shape_list) != 3 or not all(x.isdigit() for x in shape_list):
                logger.error(f'Invalid shape {shape}.')
                return False
            self._commands.append(
                f'{self.__bin_path} -m {shape_list[0]} -n {shape_list[1]} -k {shape_list[2]} '
                f'-b {self._args.batch} -w {self._args.num_warmup} -i {self._args.num_steps} -t {self._args.in_type}'
            )

        return True

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

        self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

        try:
            fields = raw_output.strip().split()
            if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
                raise ValueError('Invalid result.')
            self._result.add_result(f'{self._args.in_type}_{"_".join(fields[:3])}_flops', float(fields[-1]))
        except BaseException as e:
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False

        return True


BenchmarkRegistry.register_benchmark('cublaslt-gemm', CublasLtBenchmark, platform=Platform.CUDA)
```
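To make the data flow concrete: with the defaults above, `_preprocess` emits one `cublaslt_fp8_gemm` command per shape, and `_process_raw_result` expects a six-field output line whose first three fields echo m/n/k and whose last field is the measured TFLOPS. A standalone sketch of that mapping, restating the module's own logic (the sample output line is illustrative, not captured from a real run):

```python
# How a shape becomes a command line, using the module's defaults
# (batch=0, num_warmup=20, num_steps=50, in_type='fp8e4m3').
shape = '4096,4096,4096'
m, n, k = shape.split(',')
command = f'cublaslt_fp8_gemm -m {m} -n {n} -k {k} -b 0 -w 20 -i 50 -t fp8e4m3'

# How a raw output line becomes a metric (the numbers are made up for illustration).
raw_output = '4096 4096 4096 0 1.111 2.222'
fields = raw_output.strip().split()
metric = f'fp8e4m3_{"_".join(fields[:3])}_flops'    # 'fp8e4m3_4096_4096_4096_flops'
value = float(fields[-1])                           # 2.222 TFLOPS
print(metric, value)
```

This naming yields the `cublaslt-gemm/dtype_m_n_k_flops` metric pattern documented in micro-benchmarks.md above.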
tests/benchmarks/micro_benchmarks/test_cublaslt_function.py (new file, mode 0 → 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for cublaslt-gemm benchmark."""

import unittest
from types import SimpleNamespace

from tests.helper.testcase import BenchmarkTestCase
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
from superbench.benchmarks.result import BenchmarkResult


class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
    """Class for cublaslt-gemm benchmark test cases."""
    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in the class."""
        super().setUpClass()
        cls.benchmark_name = 'cublaslt-gemm'
        cls.createMockEnvs(cls)
        cls.createMockFiles(cls, ['bin/cublaslt_fp8_gemm'])

    def test_cublaslt_gemm_cls(self):
        """Test cublaslt-gemm benchmark class."""
        for platform in Platform:
            (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
            if platform is Platform.CUDA:
                self.assertIsNotNone(benchmark_cls)
            else:
                self.assertIsNone(benchmark_cls)

    def test_cublaslt_gemm_result_parsing(self):
        """Test cublaslt-gemm benchmark result parsing."""
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
        benchmark = benchmark_cls(self.benchmark_name, parameters='')
        benchmark._args = SimpleNamespace(shapes=['16,16,16', '32,64,128'], in_type='fp8e4m3', log_raw_data=False)
        benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)

        # Positive case - valid raw output
        self.assertTrue(benchmark._process_raw_result(0, '16 16 16 0 1.111 2.222'))
        self.assertTrue(benchmark._process_raw_result(1, '32 64 128 0 1.111 2.222'))
        self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)

        self.assertEqual(3, len(benchmark.result))
        for shape in benchmark._args.shapes:
            self.assertEqual(2.222, benchmark.result[f'fp8e4m3_{shape.replace(",", "_")}_flops'][0])

        # Negative case - invalid raw output
        self.assertFalse(benchmark._process_raw_result(1, 'cuBLAS API failed'))
```
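Since `setUpClass` mocks the `bin/cublaslt_fp8_gemm` binary, these tests exercise only registry selection and output parsing, so no GPU or compiled benchmark should be needed. (The expected result length of 3 presumably counts the two shape metrics plus the return-code entry that `BenchmarkResult` tracks.) A minimal sketch for running just this module with the standard-library runner, assuming the repository root is on `sys.path`:

```python
import unittest

# Load and run only the cublaslt-gemm tests; no CUDA device or built binary
# is required because the benchmark executable is mocked in setUpClass.
suite = unittest.defaultTestLoader.loadTestsFromName('tests.benchmarks.micro_benchmarks.test_cublaslt_function')
unittest.TextTestRunner(verbosity=2).run(suite)
```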