Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
72ba94a1
Commit
72ba94a1
authored
Dec 19, 2025
by
wenjh
Browse files
Add bias fwd/bwd at group gemm
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
e698a0a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
16 deletions
+45
-16
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+4
-3
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+41
-13
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
72ba94a1
...
@@ -1405,13 +1405,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
...
@@ -1405,13 +1405,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n
.
push_back
(
B0
);
n
.
push_back
(
B0
);
}
}
}
}
bool
use_bias
=
biasTensor
[
0
]
->
data
.
dptr
!=
nullptr
?
true
:
false
;
Tensor
*
wspace
=
convertNVTETensorCheck
(
workspace
[
0
]);
Tensor
*
wspace
=
convertNVTETensorCheck
(
workspace
[
0
]);
if
((
biasTensor
[
0
]
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
[
0
]
->
data
.
dptr
!=
nullptr
)
)
{
if
(
outputGelu
[
0
]
->
data
.
dptr
!=
nullptr
)
{
NVTE_ERROR
(
"MOE nvte_grouped_gemm not surpport
bias or
gelu."
);
NVTE_ERROR
(
"MOE nvte_grouped_gemm not surpport gelu."
);
}
}
hipblaslt_groupedgemm
(
inputA
,
inputB
,
outputD
,
m
,
n
,
k
,
b
,
hipblaslt_groupedgemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
use_bias
,
grad
,
m
,
n
,
k
,
b
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
72ba94a1
...
@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
...
@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
return
val
;
return
val
;
}
}
template
<
typename
InputType
>
template
<
typename
InputType
,
typename
OutputType
>
__launch_bounds__
(
1024
)
__global__
__launch_bounds__
(
1024
)
__global__
void
bias_gradient_kernel_v2
(
float
*
dst
,
const
InputType
*
src
,
int
M
,
int
N
)
{
void
bias_gradient_kernel_v2
(
OutputType
*
dst
,
const
InputType
*
src
,
int
M
,
int
N
)
{
__shared__
float
g_shared
[
kColwiseReduceTileSize
][
kColwiseReduceTileSize
];
__shared__
float
g_shared
[
kColwiseReduceTileSize
][
kColwiseReduceTileSize
];
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
grad_sum
=
0.
f
;
float
grad_sum
=
0.
f
;
...
@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
...
@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
;
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
;
if
(
j
<
N
)
{
if
(
j
<
N
)
{
dst
[
j
]
=
static_cast
<
float
>
(
sum
);
dst
[
j
]
=
static_cast
<
OutputType
>
(
sum
);
}
}
}
}
}
}
...
@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
...
@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
}
}
}
}
template
<
typename
Tin
>
template
<
typename
Tin
,
typename
Tout
>
void
bias_gradient_kernelLauncher
(
const
Tin
*
in
,
floa
t
*
out
,
int
m
,
int
n
,
bool
stream_order_alloc
,
void
bias_gradient_kernelLauncher
(
const
Tin
*
in
,
Tou
t
*
out
,
int
m
,
int
n
,
bool
stream_order_alloc
,
hipStream_t
stream
)
{
hipStream_t
stream
)
{
dim3
block
,
grid
;
dim3
block
,
grid
;
constexpr
int
THREADS_PER_BLOCK
=
1024
;
constexpr
int
THREADS_PER_BLOCK
=
1024
;
...
@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
...
@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
block
.
x
=
THREADS_PER_BLOCK
;
block
.
x
=
THREADS_PER_BLOCK
;
grid
.
x
=
BLOCKS_PER_COL
*
n
;
grid
.
x
=
BLOCKS_PER_COL
*
n
;
if
(
!
stream_order_alloc
)
{
if
(
!
stream_order_alloc
)
{
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
floa
t
)));
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
Tou
t
)));
}
else
{
}
else
{
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
floa
t
),
stream
));
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
Tou
t
),
stream
));
}
}
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int
B
=
(
n
-
1
)
/
kColwiseReduceTileSize
+
1
;
int
B
=
(
n
-
1
)
/
kColwiseReduceTileSize
+
1
;
bias_gradient_kernel_v2
<
Tin
>
bias_gradient_kernel_v2
<
Tin
,
Tout
>
<<<
B
,
dim3
(
kColwiseReduceTileSize
,
kColwiseReduceTileSize
),
0
,
stream
>>>
(
out
,
in
,
m
,
n
);
<<<
B
,
dim3
(
kColwiseReduceTileSize
,
kColwiseReduceTileSize
),
0
,
stream
>>>
(
out
,
in
,
m
,
n
);
}
}
...
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
...
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
}
}
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
if
(
handle
!=
nullptr
)
if
(
handle
!=
nullptr
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
}
}
}
}
...
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
...
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{
{
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgs
Buffer
&
operator
=
(
const
HipBlasltUserArgs
Buffer
&
)
=
delete
;
HipBlasltUserArgs
Cache
&
operator
=
(
const
HipBlasltUserArgs
Cache
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
{
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
...
@@ -1427,7 +1427,7 @@ struct HipBlasltUserArgsCacheManager {
...
@@ -1427,7 +1427,7 @@ struct HipBlasltUserArgsCacheManager {
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
const
Tensor
*>&
bias
,
bool
use_bias
,
bool
grad
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
void
*
workspace
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
...
@@ -1467,6 +1467,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1467,6 +1467,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
if
(
use_bias
&&
!
grad
)
{
const
hipDataType
bias_type
=
get_hipblaslt_dtype
(
bias
[
0
]
->
data
.
dtype
);
NVTE_CHECK
(
bias_type
==
HIP_R_32F
||
bias_type
==
HIP_R_16BF
);
epilogue
[
0
].
mode
=
HIPBLASLT_EPILOGUE_BIAS
;
epilogue
[
0
].
bias_data_type
=
bias_type
;
}
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
assert
(
m
[
i
]
!=
0
);
assert
(
m
[
i
]
!=
0
);
...
@@ -1477,6 +1484,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1477,6 +1484,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
d
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
d
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
bias
=
bias
[
i
]
->
data
.
dptr
;
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
}
}
...
@@ -1512,6 +1520,26 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1512,6 +1520,26 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
device_user_args
.
setStream
(
stream
);
device_user_args
.
setStream
(
stream
);
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
if
(
use_bias
&&
grad
)
{
DType
input_type
=
inputB
[
0
]
->
data
.
dtype
;
DType
bias_type
=
bias
[
0
]
->
data
.
dtype
;
NVTE_CHECK
(
bias_type
==
DType
::
kFloat32
||
bias_type
==
DType
::
kFloat16
||
bias_type
==
DType
::
kBFloat16
);
for
(
int
i
=
0
;
i
<
m
.
size
();
++
i
)
{
void
*
input_ptr
=
inputB
[
i
]
->
data
.
dptr
;
void
*
bias_ptr
=
bias
[
i
]
->
data
.
dptr
;
int
batch_size
=
static_cast
<
int
>
(
k
[
i
]);
int
output_dim
=
static_cast
<
int
>
(
n
[
i
]);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
input_type
,
IType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
bias_type
,
OType
,
detail
::
bias_gradient_kernelLauncher
<
IType
,
OType
>
(
reinterpret_cast
<
const
IType
*>
(
input_ptr
),
reinterpret_cast
<
OType
*>
(
bias_ptr
),
batch_size
,
output_dim
,
true
,
stream
);));
}
}
}
}
#endif //USE_HIPBLASLT
#endif //USE_HIPBLASLT
...
@@ -1738,7 +1766,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1738,7 +1766,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
output_dtype
,
OType
,
output_dtype
,
OType
,
detail
::
bias_gradient_kernelLauncher
<
OType
>
(
detail
::
bias_gradient_kernelLauncher
<
OType
,
float
>
(
reinterpret_cast
<
const
OType
*>
(
D
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
reinterpret_cast
<
const
OType
*>
(
D
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
input_dim
,
stream_order_alloc
,
stream
););
input_dim
,
stream_order_alloc
,
stream
););
...
@@ -1808,7 +1836,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1808,7 +1836,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
DType
bias_dtype
=
get_transformer_engine_dtype
(
bias_type
);
DType
bias_dtype
=
get_transformer_engine_dtype
(
bias_type
);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
input_dtype
,
IType
,
input_dtype
,
IType
,
detail
::
bias_gradient_kernelLauncher
<
IType
>
(
detail
::
bias_gradient_kernelLauncher
<
IType
,
float
>
(
reinterpret_cast
<
const
IType
*>
(
B
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
reinterpret_cast
<
const
IType
*>
(
B
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
output_dim
,
stream_order_alloc
,
stream
););
output_dim
,
stream_order_alloc
,
stream
););
if
(
bias_type
!=
rocblas_datatype_f32_r
)
{
if
(
bias_type
!=
rocblas_datatype_f32_r
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment