Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
6ed9a3e4
Commit
6ed9a3e4
authored
May 22, 2025
by
wenjh
Browse files
[DCU] Add width to __shfl
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
b27e513d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...e/common/transpose/quantize_transpose_square_blockwise.cu
+4
-2
No files found.
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @
6ed9a3e4
...
...
@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// broadcast the amax to all threads in a warp from the lane 0
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
,
THREADS_PER_WARP
);
__syncthreads
();
#else
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
#endif
...
...
@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
// broadcast the amax to all threads in a warp from the lane 0
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
,
THREADS_PER_WARP
);
__syncthreads
();
#else
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment