Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
d1bf10d3
Commit
d1bf10d3
authored
Nov 14, 2025
by
lishen
Browse files
基于rocm的DeepEP,低延迟优化
parent
ee3551ab
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
88 additions
and
62 deletions
+88
-62
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+88
-62
No files found.
csrc/kernels/internode_ll.cu
View file @
d1bf10d3
...
@@ -36,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
...
@@ -36,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
while
(
__hip_atomic_load
(
global_counter
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
!=
num_blocks
);
while
(
__hip_atomic_load
(
global_counter
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
!=
num_blocks
);
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -69,7 +69,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -69,7 +69,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
for
(
int
i
=
thread_id
;
i
<
num_clean_int_1
;
i
+=
kNumThreads
)
for
(
int
i
=
thread_id
;
i
<
num_clean_int_1
;
i
+=
kNumThreads
)
clean_1
[
i
]
=
0
;
clean_1
[
i
]
=
0
;
// Barrier after cleaning (make sure low-latency mode work
// Barrier after cleaning (make sure low-latency mode work
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
internode
::
shmem_device_barrier_all
();
internode
::
shmem_device_barrier_all
();
}
}
...
@@ -96,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -96,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
bool
round_scale
,
int
phases
)
{
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
...
@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
__syncthreads
();
// Expert counts
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
...
@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_DISPATCH_RECV
;
goto
LOW_LATENCY_DISPATCH_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// There are 2 kinds of warps in this part:
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
// 2. The last warp for reading `topk_idx` and count for per-expert information
...
@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
#if !defined(ROCM_DISABLE_CTX)
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
internode
::
shmem_fence
();
#else
internode
::
shmemx_int8_put_nbi_warp
(
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
}
}
}
}
//revert sync_large_warp_counters to 0 for next sync
__syncthreads
();
__syncthreads
();
// Issue count sends
// Issue count sends
...
@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
internode
::
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
internode
::
shmem_long_atomic_add
(
#endif
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
...
@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
syncwarp
();
syncwarp
();
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
LOW_LATENCY_DISPATCH_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
...
@@ -312,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -312,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV:
grid_barrier
(
global_atomic_counter
,
num_sms
);
grid_barrier
(
global_atomic_counter
,
num_sms
);
}
}
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
__syncthreads
();
// Receiving and packing
// Receiving and packing
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
rdma_recv_x_uint8
=
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
const
auto
rdma_recv_x_uint8
=
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
}
}
}
}
}
}
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
...
@@ -407,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -407,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
...
@@ -464,11 +476,6 @@ combine(void* combined_x,
...
@@ -464,11 +476,6 @@ combine(void* combined_x,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
...
@@ -488,7 +495,7 @@ combine(void* combined_x,
...
@@ -488,7 +495,7 @@ combine(void* combined_x,
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
if
(
threadIdx
.
x
==
0
){
if
(
threadIdx
.
x
==
0
){
...
@@ -503,6 +510,11 @@ combine(void* combined_x,
...
@@ -503,6 +510,11 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_COMBINE_RECV
;
goto
LOW_LATENCY_COMBINE_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// Clean up next buffer
// Clean up next buffer
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
#pragma unroll
#pragma unroll
...
@@ -522,10 +534,10 @@ combine(void* combined_x,
...
@@ -522,10 +534,10 @@ combine(void* combined_x,
const
auto
global_expert_idx
=
rank
*
num_local_experts
+
local_expert_idx
;
const
auto
global_expert_idx
=
rank
*
num_local_experts
+
local_expert_idx
;
const
auto
layout
=
__ldg
(
layout_range
+
local_expert_idx
*
num_ranks
+
dst_rank
);
const
auto
layout
=
__ldg
(
layout_range
+
local_expert_idx
*
num_ranks
+
dst_rank
);
const
auto
local_x
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
const
auto
local_x
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_bf16_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_bf16_int4
;
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
// Unpack layout
// Unpack layout
int
offset
,
num_tokens_to_send
;
int
offset
,
num_tokens_to_send
;
...
@@ -548,21 +560,16 @@ combine(void* combined_x,
...
@@ -548,21 +560,16 @@ combine(void* combined_x,
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
internode
::
shmemx_int8_put_nbi_warp
(
#else
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
#if defined(ROCM_DISABLE_CTX)
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
internode
::
shmem_fence
();
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#else
internode
::
shmem
_ctx_quiet
(
ctx
);
internode
::
shmem
x_int8_put_nbi_warp
(
#endif
#endif
}
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
}
}
}
// Put finishing flag
// Put finishing flag
...
@@ -573,27 +580,49 @@ combine(void* combined_x,
...
@@ -573,27 +580,49 @@ combine(void* combined_x,
}
}
syncwarp
();
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
);
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(ROCM_DISABLE_CTX)
#if
!
defined(ROCM_DISABLE_CTX)
internode
::
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
internode
::
shmem_
ctx_
long_atomic_add
(
ctx
,
#else
#else
internode
::
shmem_
ctx_
long_atomic_add
(
ctx
,
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
internode
::
shmem_long_atomic_add
(
#endif
#endif
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
}
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
}
}
syncwarp
();
syncwarp
();
if
(
num_ranks
>
8
){
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_quiet
(
ctx
);
#else
internode
::
shmem_fence
();
#endif
}
}
}
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
LOW_LATENCY_COMBINE_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
...
@@ -641,9 +670,6 @@ combine(void* combined_x,
...
@@ -641,9 +670,6 @@ combine(void* combined_x,
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
thread_id
]
=
combined_int4
;
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
thread_id
]
=
combined_int4
;
}
}
}
}
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment