Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
736e8f8b
Commit
736e8f8b
authored
Jul 18, 2025
by
yuguo
Browse files
[DCU] fix bias gradient
parent
f5349823
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
1 deletion
+36
-1
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+36
-1
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
736e8f8b
...
...
@@ -352,6 +352,39 @@ void __launch_bounds__(THREADS_PER_BLOCK) bias_gradient_kernel(const Tin* in, fl
atomicAdd
(
&
out
[
col_idx
],
local_sum
);
}
constexpr
int
kColwiseReduceTileSize
=
32
;
// Shuffle-based tree reduction within a warp/wavefront.
// Starting at `max` and halving each step, every lane adds in the value
// held by the lane `offset` positions above it; after the loop, lane 0 of
// each participating lane group holds the group's sum of `val`.
// NOTE(review): the in-file caller passes max = kColwiseReduceTileSize / 2
// (i.e. 16) to reduce a 32-lane group. The default max = 32 would begin
// with an offset-32 shuffle, which on a 32-wide warp is out of range and
// on a 64-wide AMD wavefront reads from the neighboring 32-lane group —
// confirm intent before relying on the default.
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, int max = 32) {
  int offset = max;
  while (offset > 0) {
    val += __shfl_down(val, offset);
    offset >>= 1;
  }
  return val;
}
// Column-wise bias-gradient reduction: dst[j] = sum over rows i of src[i*N + j],
// accumulated in float, for a row-major M x N input. One block covers
// kColwiseReduceTileSize consecutive columns; expected launch shape is
// blockDim = (kColwiseReduceTileSize, kColwiseReduceTileSize) — TODO confirm
// against the launcher.
// @param dst  output column sums, length N (written, not accumulated into)
// @param src  row-major M x N input
// @param M    number of rows
// @param N    number of columns
template <typename InputType>
__launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
  // Square staging tile used to transpose partial sums between threads.
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  // Column handled by this thread during the strided accumulation phase.
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  if (j < N) {
    // Each thread strides down its column with step blockDim.y, so the
    // blockDim.y threads sharing a column split its M rows between them.
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      grad_sum += static_cast<float>(src[i * N + j]);
    }
  }
  // Out-of-range columns still store their (zero) partial so the whole
  // tile is defined before the transposed read below.
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  // Transposed read: after this, threads sharing a threadIdx.y value hold
  // the partials of one column spread across threadIdx.x.
  float sum = g_shared[threadIdx.x][threadIdx.y];
  // Reduce the kColwiseReduceTileSize partials of each column across
  // threadIdx.x; first shuffle offset is tile/2 (=16) so the reduction
  // spans exactly one 32-lane group. NOTE(review): on 64-wide AMD
  // wavefronts intermediate lanes read across the 32-lane boundary, but
  // only lane 0 of each group is consumed below — verify on target HW.
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    // Recompute the column index from threadIdx.y: after the transpose,
    // thread (0, y) holds the total for column blockIdx.x*blockDim.x + y.
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = static_cast<float>(sum);
    }
  }
}
template
<
typename
Tin
>
void
bias_gradient_kernelLauncher
(
const
Tin
*
in
,
float
*
out
,
int
m
,
int
n
,
bool
stream_order_alloc
,
hipStream_t
stream
)
{
dim3
block
,
grid
;
...
...
@@ -364,7 +397,9 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
}
else
{
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
}
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int
B
=
(
n
-
1
)
/
kColwiseReduceTileSize
+
1
;
bias_gradient_kernel_v2
<
Tin
><<<
B
,
dim3
(
kColwiseReduceTileSize
,
kColwiseReduceTileSize
),
0
,
stream
>>>
(
out
,
in
,
m
,
n
);
}
}
// namespace detail
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment