OpenDAS / TransformerEngine · Commit 11864d3d

Commit 11864d3d, authored Aug 23, 2025 by yuguo

[DCU] tensorwise int8 gemm support bias

Parent: 32edae18
Changes: 2 files, 77 additions and 10 deletions

transformer_engine/common/gemm/rocm_gemm.cu         +77 −9
transformer_engine/pytorch/cpp_extensions/gemm.py   +0  −1
transformer_engine/common/gemm/rocm_gemm.cu
@@ -384,6 +384,30 @@ __launch_bounds__(1024) __global__
   }
 }
 
+template <typename OutputType>
+__launch_bounds__(1024) __global__
+void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale,
+                                          int M, int N) {
+  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
+  const int j = blockIdx.x * blockDim.x + threadIdx.x;
+  float grad_sum = 0.f;
+  float tensorwise_scale = scale[0];
+  if (j < N) {
+    for (int i = threadIdx.y; i < M; i += blockDim.y) {
+      grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
+    }
+  }
+  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
+  __syncthreads();
+  float sum = g_shared[threadIdx.x][threadIdx.y];
+  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
+  if (threadIdx.x == 0) {
+    const int j = blockIdx.x * blockDim.x + threadIdx.y;
+    if (j < N) {
+      dst[j] = static_cast<OutputType>(sum);
+    }
+  }
+}
+
 template <typename Tin>
 void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n,
                                   bool stream_order_alloc, hipStream_t stream) {
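The added kernel computes the bias gradient of a tensorwise-int8 GEMM: each output element dst[j] is the column sum of the int8 gradient matrix src, dequantized by the single scale scale[0], with the shared-memory transpose and warp reduction collapsing each column tile. A minimal CPU reference of the same math (illustrative only, not part of the commit; the function name is invented here):

#include <cstdint>

// CPU reference for tensorwise_int8_bias_gradient_kernel:
// dst[j] = sum_i src[i * N + j] * scale[0]
template <typename OutputType>
void bias_gradient_reference(OutputType* dst, const int8_t* src, const float* scale,
                             int M, int N) {
  const float s = scale[0];        // one scale for the whole tensor
  for (int j = 0; j < N; ++j) {    // one bias-gradient entry per output column
    float acc = 0.f;
    for (int i = 0; i < M; ++i) {  // reduce over the M rows
      acc += static_cast<float>(src[i * N + j]) * s;
    }
    dst[j] = static_cast<OutputType>(acc);
  }
}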
@@ -403,6 +427,19 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
       <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
 }
 
+template <typename Tout>
+void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale,
+                                                  int m, int n, hipStream_t stream) {
+  dim3 block, grid;
+  constexpr int THREADS_PER_BLOCK = 1024;
+  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
+  block.x = THREADS_PER_BLOCK;
+  grid.x = BLOCKS_PER_COL * n;
+  NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
+  int B = (n - 1) / kColwiseReduceTileSize + 1;
+  tensorwise_int8_bias_gradient_kernel<Tout>
+      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m, n);
+}
 }  // namespace detail
 
 transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
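The launcher zero-fills the output, then covers the n bias columns with B = ceil(n / kColwiseReduceTileSize) blocks of kColwiseReduceTileSize × kColwiseReduceTileSize threads; note that the block/grid locals derived from THREADS_PER_BLOCK are not referenced by the launch itself. A small standalone sketch of the launch-shape arithmetic (kColwiseReduceTileSize's value is not visible in this hunk; 32 is assumed here):

#include <cstdio>

int main() {
  const int kColwiseReduceTileSize = 32;  // assumed; not shown in the hunk
  const int n = 4096;                     // example bias length (output columns)
  // Matches "int B = (n - 1) / kColwiseReduceTileSize + 1;" above.
  const int B = (n - 1) / kColwiseReduceTileSize + 1;
  std::printf("%d blocks of %dx%d threads cover %d columns\n",
              B, kColwiseReduceTileSize, kColwiseReduceTileSize, n);
  return 0;
}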
@@ -962,6 +999,20 @@ static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
   }
 }
 
+transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
+  using namespace transformer_engine;
+  switch (t) {
+    case HIP_R_16F:
+      return DType::kFloat16;
+    case HIP_R_32F:
+      return DType::kFloat32;
+    case HIP_R_16BF:
+      return DType::kBFloat16;
+    default:
+      NVTE_ERROR("Invalid type");
+  }
+}
+
 void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                     const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k,
                     int lda, int ldb, int ldd, hipblasOperation_t transa,
                     hipblasOperation_t transb,
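The new helper inverts the usual dtype mapping so the hipBLASLt bias dtype can drive Transformer Engine's compile-time type dispatch in the bias-branch hunk below. A minimal standalone analogue covering only the three dtypes the hunk handles (illustrative; enum and function names are invented here):

#include <cstdio>
#include <cstdlib>

enum class HipDtype { R_16F, R_32F, R_16BF };          // stand-ins for HIP_R_16F etc.
enum class TeDtype  { kFloat16, kFloat32, kBFloat16 };

TeDtype to_te_dtype(HipDtype t) {
  switch (t) {
    case HipDtype::R_16F:  return TeDtype::kFloat16;
    case HipDtype::R_32F:  return TeDtype::kFloat32;
    case HipDtype::R_16BF: return TeDtype::kBFloat16;
  }
  std::fprintf(stderr, "Invalid type\n");              // mirrors NVTE_ERROR
  std::abort();
}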
@@ -1090,9 +1141,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                     HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                     (void*)&B_scale_inverse_float, sizeof(void*)));
-    if (bias) {
-      NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
-    }
   }
 
   if (bias && gelu) {
@@ -1109,14 +1157,34 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
     NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
         operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
   } else if (bias) {
-    if (grad) {
-      // grad output is always input B
-      epilogue = HIPBLASLT_EPILOGUE_BGRADB;
-    } else {
-      epilogue = HIPBLASLT_EPILOGUE_BIAS;
-    }
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
+    if (tensorwise_int8) {
+      if (grad) {
+        int batch_size = k;
+        int output_dim = n;
+        DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(te_bias_dtype, BType,
+            detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
+                reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr),
+                B_scale_inverse_float, batch_size, output_dim, stream););
+      } else {
+        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
+            operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
+        epilogue = HIPBLASLT_EPILOGUE_BIAS;
+        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
+            operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
+      }
+    } else {
+      if (grad) {
+        // grad output is always input B
+        epilogue = HIPBLASLT_EPILOGUE_BGRADB;
+      } else {
+        epilogue = HIPBLASLT_EPILOGUE_BIAS;
+      }
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
+          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
+    }
   } else if (gelu) {
     if (grad) {
       epilogue = HIPBLASLT_EPILOGUE_DGELU;
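With this hunk, the tensorwise-int8 bias gradient no longer goes through hipBLASLt's BGRADB epilogue; the hand-written reduction kernel is launched instead, presumably because the epilogue cannot apply the int8 dequantization scale (an inference from the diff, not stated in the commit). A self-contained sketch of the new decision logic (illustrative; the enum and function names are invented here):

enum class Epilogue { None, Bias, BGradB };

// Mirrors the bias branch above: which hipBLASLt epilogue to request, and
// whether the custom int8 bias-gradient kernel is launched instead.
Epilogue choose_bias_path(bool grad, bool tensorwise_int8, bool& use_custom_kernel) {
  use_custom_kernel = false;
  if (tensorwise_int8) {
    if (grad) {
      use_custom_kernel = true;  // detail::tensorwise_int8_bias_gradient_kernelLauncher
      return Epilogue::None;
    }
    return Epilogue::Bias;       // forward bias: set bias data type + pointer attributes
  }
  return grad ? Epilogue::BGradB : Epilogue::Bias;
}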
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -185,7 +185,6 @@ def general_gemm(
     if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
         assert not gelu, "GELU not supported with int8 simulation"
         assert gelu_in is None, "GELU input not supported with int8 simulation"
-        assert bias is None, "Bias not supported with int8 simulation"
         assert ub is None, "User buffer not supported with int8 simulation"
         assert ub_type is None, "User buffer type not supported with int8 simulation"
         assert extra_output is None, "Extra output not supported with int8 simulation"