OpenDAS / TransformerEngine

Commit eac75188
authored Jul 22, 2025 by yuguo

[DCU] fix merge

parent 44740c6c
Showing 3 changed files with 26 additions and 1 deletion (+26, -1):
transformer_engine/common/fused_router/fused_moe_aux_loss.cu   +4  -0
transformer_engine/common/fused_router/utils.h                 +21 -0
transformer_engine/pytorch/csrc/quantizer.cpp                  +1  -1
transformer_engine/common/fused_router/fused_moe_aux_loss.cu

...
@@ -168,6 +168,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     int num_cols, int topk, float coeff, DataType* aux_loss, float* Const_buf,
     cudaStream_t stream) {
+#ifndef __HIP_PLATFORM_AMD__
   if (cuda::sm_arch(cuda::current_device()) >= 90) {
     cudaLaunchConfig_t config = {0};
     int cluster_size = 8;
...
@@ -193,11 +194,14 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
         tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff,
         aux_loss, Const_buf);
   } else {
+#endif
     size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel<DataType, IndexType><<<1, 1024, smem_size, stream>>>(
         probs, tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
         coeff, aux_loss, Const_buf);
+#ifndef __HIP_PLATFORM_AMD__
   }
+#endif
 }

 void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
...
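The launcher change above simply fences the Hopper-only path off from HIP builds: the SM90 thread-block-cluster launch (cudaLaunchConfig_t plus a cluster attribute) is CUDA-specific, so on DCU/HIP only the plain triple-chevron launch is compiled. Below is a minimal, self-contained sketch of the same guard pattern; the kernel, the grid/block sizes, and the cudaLaunchKernelEx call are illustrative stand-ins for the TE launch code elided above, not copies of it.

// Sketch of the #ifndef __HIP_PLATFORM_AMD__ guard pattern used in the diff.
// The kernel and launch parameters are hypothetical.
#include <cuda_runtime.h>

__global__ void demo_kernel(float* out) { out[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f; }

void launch_demo(float* out, cudaStream_t stream, bool has_sm90) {
#ifndef __HIP_PLATFORM_AMD__
  if (has_sm90) {  // stands in for cuda::sm_arch(cuda::current_device()) >= 90
    cudaLaunchConfig_t config = {0};
    config.gridDim = dim3(8);
    config.blockDim = dim3(1024);
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeClusterDimension;
    attrs[0].val.clusterDim.x = 8;  // cluster_size = 8, as in the diff
    attrs[0].val.clusterDim.y = 1;
    attrs[0].val.clusterDim.z = 1;
    config.attrs = attrs;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, demo_kernel, out);  // cluster launch, CUDA/SM90 only
  } else {
#endif
    demo_kernel<<<8, 1024, 0, stream>>>(out);  // portable fallback, the only path on HIP
#ifndef __HIP_PLATFORM_AMD__
  }
#endif
}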
transformer_engine/common/fused_router/utils.h

...
@@ -39,11 +39,19 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_
   }
   // Warp shuffle between threads
+#ifdef __HIP_PLATFORM_AMD__
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
+#else
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
+#endif
   __syncwarp();
   return T(val);
 }

...
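The only difference between the two branches above is the signature of the shuffle intrinsic: the HIP build passes a 64-bit lane mask and an explicit width (kThreadsPerWarp), while the CUDA build passes a 32-bit mask and lets the width default to warpSize. The loop-form sketch below illustrates that split; it is not the TE code. It assumes a ROCm/DTK toolchain that exposes the 64-bit-mask __shfl_xor_sync overload (as this diff relies on), assumes kThreadsPerWarp is 32 to match the fixed offsets 16..1 used in utils.h, and uses a plain sum instead of TE's reduce_func pointer.

// Loop form of the unrolled XOR-butterfly reduction above (illustration only).
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

constexpr int kThreadsPerWarp = 32;  // assumption: matches TE's constant and the 16..1 offsets

__device__ inline float warp_butterfly_sum(float val) {
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    // HIP overload: 64-bit mask (cast exactly as in the diff) plus an explicit shuffle width.
    val += __shfl_xor_sync((unsigned long long)0xffffffff, val, offset, kThreadsPerWarp);
#else
    // CUDA intrinsic: 32-bit mask; the width defaults to warpSize.
    val += __shfl_xor_sync(0xffffffff, val, offset);
#endif
  }
  return val;  // every participating lane now holds the sum of its 32-lane group
}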
@@ -71,11 +79,19 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat
   }
   // Warp shuffle between threads
+#ifdef __HIP_PLATFORM_AMD__
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
+  val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
+#else
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
   val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
+#endif
   __syncwarp();
   return T(val);
 }

...
@@ -165,8 +181,13 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
   }
   // Warp shuffle between threads
   for (int s = 16; s > 0; s /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    volatile auto shuffled_val = __shfl_xor_sync((unsigned long long)0xffffffff, val, s, kThreadsPerWarp);
+    volatile auto shuffled_index = __shfl_xor_sync((unsigned long long)0xffffffff, index, s, kThreadsPerWarp);
+#else
     volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
     volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
+#endif
     if (shuffled_val > val) {
       val = shuffled_val;
       index = shuffled_index;
...
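The third hunk applies the same treatment to the argmax step of naive_topk_and_mask: each round of the XOR butterfly pulls the partner lane's (value, index) pair and keeps whichever value is larger, so after the five rounds every lane holds the warp-wide maximum and its index. A stripped-down sketch of that idiom follows (CUDA branch only for brevity; the HIP branch differs just in the mask/width arguments shown above).

// Warp-level argmax via XOR-butterfly shuffles (illustration of the idiom in utils.h).
// On ties, different lanes may end up keeping different indices.
__device__ inline void warp_argmax(float &val, int &index) {
  for (int s = 16; s > 0; s /= 2) {
    float other_val = __shfl_xor_sync(0xffffffff, val, s);
    int other_index = __shfl_xor_sync(0xffffffff, index, s);
    if (other_val > val) {
      val = other_val;      // keep the larger value...
      index = other_index;  // ...and the index it came from
    }
  }
}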
transformer_engine/pytorch/csrc/quantizer.cpp

...
@@ -382,7 +382,7 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
   size_t k_dim = shape.size() == 0 ? 1u : shape.back();
   size_t m_dim = numel / k_dim;
-  constexpr size_t kBlockLen = 128;
+  size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
   Float8BlockScaleTensorFormat data_format =
       (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
...
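The quantizer.cpp change swaps the hard-coded FP8 block length of 128 for a value queried at runtime through blockwise_fp8_block_len(), which is why kBlockLen loses its constexpr. The code that consumes kBlockLen is elided above; the sketch below only shows the typical pattern such a block length feeds into (a ceil-division per dimension) and is an assumption, not TE's actual get_scale_shape logic.

// Hypothetical sketch of how a runtime block length usually enters a scale-shape
// computation; not taken from TransformerEngine. blockwise_fp8_block_len() here is a
// local stand-in returning the previous default of 128.
#include <cstddef>
#include <vector>

inline std::size_t blockwise_fp8_block_len() { return 128; }  // stand-in, not TE's function

inline std::vector<std::size_t> scale_shape_sketch(std::size_t m_dim, std::size_t k_dim) {
  const std::size_t kBlockLen = blockwise_fp8_block_len();           // runtime value, was constexpr 128
  const std::size_t m_blocks = (m_dim + kBlockLen - 1) / kBlockLen;  // ceil(m_dim / kBlockLen)
  const std::size_t k_blocks = (k_dim + kBlockLen - 1) / kBlockLen;  // ceil(k_dim / kBlockLen)
  return {m_blocks, k_blocks};  // one scale per kBlockLen x kBlockLen tile, under this assumption
}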