Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
f53d941a
Unverified
Commit
f53d941a
authored
Nov 20, 2023
by
Yuting Jiang
Committed by
GitHub
Nov 20, 2023
Browse files
Benchmarks: micro benchmarks - add int8 support for cublaslt function (#574)
**Description** add int8 support for cublaslt function.
parent
c7800bb8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
4 deletions
+11
-4
superbench/benchmarks/micro_benchmarks/cublaslt_function.py
superbench/benchmarks/micro_benchmarks/cublaslt_function.py
+1
-1
superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu
...enchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu
+5
-0
superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc
...nchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc
+2
-0
tests/benchmarks/micro_benchmarks/test_cublaslt_function.py
tests/benchmarks/micro_benchmarks/test_cublaslt_function.py
+3
-3
No files found.
superbench/benchmarks/micro_benchmarks/cublaslt_function.py
View file @
f53d941a
...
@@ -23,7 +23,7 @@ class CublasLtBenchmark(MicroBenchmarkWithInvoke):
...
@@ -23,7 +23,7 @@ class CublasLtBenchmark(MicroBenchmarkWithInvoke):
super
().
__init__
(
name
,
parameters
)
super
().
__init__
(
name
,
parameters
)
self
.
_bin_name
=
'cublaslt_gemm'
self
.
_bin_name
=
'cublaslt_gemm'
self
.
_in_types
=
[
'fp64'
,
'fp32'
,
'fp16'
,
'bf16'
,
'fp8e4m3'
,
'fp8e5m2'
]
self
.
_in_types
=
[
'fp64'
,
'fp32'
,
'fp16'
,
'bf16'
,
'fp8e4m3'
,
'fp8e5m2'
,
'int8'
]
def
mrange
(
self
,
start
,
stop
=-
1
,
multiplication_factor
=
2
):
def
mrange
(
self
,
start
,
stop
=-
1
,
multiplication_factor
=
2
):
"""Range constructor with multiplication factor.
"""Range constructor with multiplication factor.
...
...
superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu
View file @
f53d941a
...
@@ -16,6 +16,7 @@ using fp16 = half;
...
@@ -16,6 +16,7 @@ using fp16 = half;
using
bf16
=
nv_bfloat16
;
using
bf16
=
nv_bfloat16
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
int8
=
int8_t
;
struct
Args
{
struct
Args
{
int
m
=
16
;
int
m
=
16
;
...
@@ -84,6 +85,8 @@ template <typename T> cudaDataType_t get_datatype() {
...
@@ -84,6 +85,8 @@ template <typename T> cudaDataType_t get_datatype() {
return
CUDA_R_8F_E4M3
;
return
CUDA_R_8F_E4M3
;
if
(
std
::
is_same
<
T
,
fp8e5m2
>::
value
)
if
(
std
::
is_same
<
T
,
fp8e5m2
>::
value
)
return
CUDA_R_8F_E5M2
;
return
CUDA_R_8F_E5M2
;
if
(
std
::
is_same
<
T
,
int8
>::
value
)
return
CUDA_R_8I
;
throw
std
::
invalid_argument
(
"Unknown type"
);
throw
std
::
invalid_argument
(
"Unknown type"
);
}
}
...
@@ -162,6 +165,8 @@ int main(int argc, char **argv) {
...
@@ -162,6 +165,8 @@ int main(int argc, char **argv) {
run
<
fp8e4m3
,
fp8e4m3
,
fp16
>
(
&
args
);
run
<
fp8e4m3
,
fp8e4m3
,
fp16
>
(
&
args
);
else
if
(
args
.
in_type
==
"fp8e5m2"
)
else
if
(
args
.
in_type
==
"fp8e5m2"
)
run
<
fp8e5m2
,
fp8e4m3
,
fp16
>
(
&
args
);
run
<
fp8e5m2
,
fp8e4m3
,
fp16
>
(
&
args
);
else
if
(
args
.
in_type
==
"int8"
)
run
<
int8
>
(
&
args
);
else
else
throw
std
::
invalid_argument
(
"Unknown type "
+
args
.
in_type
);
throw
std
::
invalid_argument
(
"Unknown type "
+
args
.
in_type
);
...
...
superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_utils.cc
View file @
f53d941a
...
@@ -62,6 +62,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
...
@@ -62,6 +62,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
gemm_compute_type
=
CUBLAS_COMPUTE_32F
;
gemm_compute_type
=
CUBLAS_COMPUTE_32F
;
if
(
a_type
==
CUDA_R_64F
||
b_type
==
CUDA_R_64F
)
if
(
a_type
==
CUDA_R_64F
||
b_type
==
CUDA_R_64F
)
gemm_compute_type
=
CUBLAS_COMPUTE_64F
;
gemm_compute_type
=
CUBLAS_COMPUTE_64F
;
if
(
a_type
==
CUDA_R_8I
)
gemm_compute_type
=
CUBLAS_COMPUTE_32I
;
cublasLtMatmulDesc_t
op_desc
=
nullptr
;
cublasLtMatmulDesc_t
op_desc
=
nullptr
;
CUBLAS_CHECK
(
cublasLtMatmulDescCreate
(
&
op_desc
,
gemm_compute_type
,
CUDA_R_32F
));
CUBLAS_CHECK
(
cublasLtMatmulDescCreate
(
&
op_desc
,
gemm_compute_type
,
CUDA_R_32F
));
...
...
tests/benchmarks/micro_benchmarks/test_cublaslt_function.py
View file @
f53d941a
...
@@ -63,15 +63,15 @@ class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
...
@@ -63,15 +63,15 @@ class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
(
benchmark_cls
,
_
)
=
BenchmarkRegistry
.
_BenchmarkRegistry__select_benchmark
(
self
.
benchmark_name
,
Platform
.
CUDA
)
(
benchmark_cls
,
_
)
=
BenchmarkRegistry
.
_BenchmarkRegistry__select_benchmark
(
self
.
benchmark_name
,
Platform
.
CUDA
)
benchmark
=
benchmark_cls
(
benchmark
=
benchmark_cls
(
self
.
benchmark_name
,
self
.
benchmark_name
,
parameters
=
'--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64'
,
parameters
=
'--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64
int8
'
,
)
)
self
.
assertTrue
(
benchmark
.
_preprocess
())
self
.
assertTrue
(
benchmark
.
_preprocess
())
self
.
assertEqual
(
4
*
(
2
*
2
*
3
+
2
)
*
3
,
len
(
benchmark
.
_commands
))
self
.
assertEqual
(
4
*
(
2
*
2
*
3
+
2
)
*
len
(
benchmark
.
_args
.
in_types
)
,
len
(
benchmark
.
_commands
))
def
cmd
(
t
,
b
,
m
,
n
,
k
):
def
cmd
(
t
,
b
,
m
,
n
,
k
):
return
f
'
{
benchmark
.
_CublasLtBenchmark__bin_path
}
-m
{
m
}
-n
{
n
}
-k
{
k
}
-b
{
b
}
-w 20 -i 50 -t
{
t
}
'
return
f
'
{
benchmark
.
_CublasLtBenchmark__bin_path
}
-m
{
m
}
-n
{
n
}
-k
{
k
}
-b
{
b
}
-w 20 -i 50 -t
{
t
}
'
for
_t
in
[
'fp16'
,
'fp32'
,
'fp64'
]:
for
_t
in
[
'fp16'
,
'fp32'
,
'fp64'
,
'int8'
]:
for
_b
in
[
2
,
4
,
8
,
16
]:
for
_b
in
[
2
,
4
,
8
,
16
]:
for
_m
in
[
2
,
4
]:
for
_m
in
[
2
,
4
]:
for
_n
in
[
4
,
8
]:
for
_n
in
[
4
,
8
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment