Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
d81f8119
Commit
d81f8119
authored
Sep 18, 2025
by
wenjh
Browse files
Adapt to changes of hipblaslt
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
3f800f01
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
27 deletions
+2
-27
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+2
-27
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
d81f8119
...
...
@@ -1076,17 +1076,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
hipblasLtMatmulPreference_t
preference
=
nullptr
;
hipblasLtEpilogue_t
epilogue
=
HIPBLASLT_EPILOGUE_DEFAULT
;
hipblasLtMatmulFlags_t
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16
;
if
(
tensorwise_int8
)
{
if
(
D_type
==
HIP_R_16BF
)
{
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16
;
}
else
if
(
D_type
==
HIP_R_32F
)
{
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32
;
}
else
{
NVTE_CHECK
(
false
,
"tensorwise_int8 only supports D_type bf16 or fp32!"
);
}
}
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
// default to tf32 except for e5m2 inputs where the config is not supported
...
...
@@ -1099,11 +1088,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
transb
==
HIPBLAS_OP_N
?
n
:
k
,
ldb
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
if
(
tensorwise_int8
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
,
matmul_flag
));
}
else
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
}
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
...
...
@@ -1450,16 +1435,6 @@ void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
hipblasLtMatrixLayout_t
Adesc
=
nullptr
,
Bdesc
=
nullptr
,
Cdesc
=
nullptr
,
Ddesc
=
nullptr
;
hipblasLtMatmulPreference_t
preference
=
nullptr
;
hipblasLtEpilogue_t
epilogue
=
HIPBLASLT_EPILOGUE_DEFAULT
;
hipblasLtMatmulFlags_t
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16
;
if
(
tensorwise_int8
)
{
if
(
D_type
==
HIP_R_16BF
)
{
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16
;
}
else
if
(
D_type
==
HIP_R_32F
)
{
matmul_flag
=
HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32
;
}
else
{
NVTE_CHECK
(
false
,
"tensorwise_int8 only supports D_type bf16 or fp32!"
);
}
}
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
...
...
@@ -1491,7 +1466,7 @@ void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
hipblasLtMatrixLayoutSetAttribute
(
Ddesc
,
HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batch_count
,
sizeof
(
int32_t
));
hipblasLtMatrixLayoutSetAttribute
(
Ddesc
,
HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideD
,
sizeof
(
int64_t
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
,
matmul_flag
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
}
else
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment