Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
0395bf27
Commit
0395bf27
authored
Mar 31, 2026
by
lishen
Browse files
nmz上normal-dispatch优化
parent
a42ecdc0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
54 deletions
+69
-54
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+44
-48
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+25
-6
No files found.
csrc/kernels/internode.cu
View file @
0395bf27
...
@@ -21,7 +21,7 @@ namespace internode {
...
@@ -21,7 +21,7 @@ namespace internode {
extern
shmem_team_t
cpu_rdma_team
;
extern
shmem_team_t
cpu_rdma_team
;
struct
SourceMeta
{
struct
SourceMeta
{
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
// sizeof(SourceMeta) = 8
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
==
8
,
"Invalid number of maximum NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
==
8
,
"Invalid number of maximum NVL peers"
);
...
@@ -619,47 +619,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -619,47 +619,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kNumTopkRDMARanks
);
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kNumTopkRDMARanks
);
//////////////// 复制数据到发送缓冲区 ////////////////
// 复制源元数据到对称发送缓冲区
if
(
lane_id
<
num_topk_ranks
)
{
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
}
// 复制 `x` 到对称发送缓冲区
// 复制 `x` 到对称发送缓冲区
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
}
}
};
};
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
}
}
// 复制源元数据到对称发送缓冲区
if
(
lane_id
<
num_topk_ranks
)
{
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
}
// 复制 `x_scales` 到对称发送缓冲区
// 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
// auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
// auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
}
}
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
}
}
// 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
// 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
auto
rank_idx
=
i
/
num_topk
,
copy_idx
=
i
%
num_topk
;
auto
rank_idx
=
i
/
num_topk
,
copy_idx
=
i
%
num_topk
;
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
...
@@ -899,7 +892,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -899,7 +892,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
auto
rdma_slot_idx
=
i
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
i
%
num_max_rdma_chunked_recv_tokens
;
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
));
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)));
if
(
lane_id
==
src_rdma_rank
)
{
if
(
lane_id
==
src_rdma_rank
)
{
num_tokens_to_recv_from_rdma
-=
1
;
num_tokens_to_recv_from_rdma
-=
1
;
}
}
...
@@ -918,37 +911,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -918,37 +911,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 获取一个空闲槽位
// 获取一个空闲槽位
int
dst_slot_idx
=
(
cached_nvl_channel_tail
++
)
%
num_max_nvl_chunked_recv_tokens
;
int
dst_slot_idx
=
(
cached_nvl_channel_tail
++
)
%
num_max_nvl_chunked_recv_tokens
;
// 复制数据
// 设置 src和dst 位置
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
auto
src_gpu_buffer_x
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
sizeof
(
SourceMeta
));
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
,
auto
src_gpu_buffer_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_x
)
+
hidden_bytes
);
reinterpret_cast
<
int4
*>
(
shifted
),
auto
src_gpu_buffer_topk_idx
=
reinterpret_cast
<
int
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_scales
)
+
num_scales
*
sizeof
(
float
));
ld_nc_global
,
st_na_global
);
auto
src_gpu_buffer_topk_weights
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_topk_idx
)
+
num_topk
*
sizeof
(
int
));
shifted
=
reinterpret_cast
<
int4
*>
(
shifted
)
+
hidden_int4
;
// 复制源元数据
auto
dst_gpu_buffer_x
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
if
(
lane_id
==
0
)
auto
dst_gpu_buffer_scales
=
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
;
st_na_global
(
nvl_channel_
src_meta
.
buffer
()
+
dst_slot_idx
,
src_meta
)
;
auto
dst_gpu_buffer_topk_idx
=
nvl_channel_
topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
;
shifted
=
reinterpret_cast
<
SourceMeta
*>
(
shifted
)
+
1
;
auto
dst_gpu_buffer_topk_weights
=
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
;
// 复制 `x_scales`
if
(
lane_id
==
0
)
{
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
st_na_global
(
reinterpret_cast
<
int64_t
*>
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
),
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
,
*
reinterpret_cast
<
int64_t
*>
(
&
src_meta
));
reinterpret_cast
<
float
*>
(
shifted
),
}
ld_nc_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
float
*>
(
shifted
)
+
num_scales
;
// 复制 `topk_idx` 和 `topk_weights`
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
if
(
lane_id
<
num_topk
)
{
dst_gpu_buffer_x
,
// 读取
src_gpu_buffer_x
,
auto
idx_value
=
ld_nc_global
(
reinterpret_cast
<
int
*>
(
shifted
)
+
lane_id
);
ld_direct_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
int
*>
(
shifted
)
+
num_topk
;
auto
weight_value
=
ld_nc_global
(
reinterpret_cast
<
float
*>
(
shifted
)
+
lane_id
);
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
dst_gpu_buffer_scales
,
// 转换和写入
src_gpu_buffer_scales
,
idx_value
=
(
idx_value
>=
dst_rank_expert_begin
&&
idx_value
<
dst_rank_expert_end
)
?
idx_value
-
dst_rank_expert_begin
:
-
1
;
ld_direct_global
,
st_na_global
);
st_na_global
(
nvl_channel_topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
idx_value
);
weight_value
=
idx_value
>=
0
?
weight_value
:
0.0
f
;
for
(
int
t
=
lane_id
;
t
<
num_topk
;
t
+=
kWarpSize
)
{
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
weight_value
);
int
idx_val
=
ld_direct_global
(
reinterpret_cast
<
int
*>
(
src_gpu_buffer_topk_idx
)
+
t
);
float
w_val
=
ld_direct_global
(
reinterpret_cast
<
float
*>
(
src_gpu_buffer_topk_weights
)
+
t
);
int
new_idx
=
(
idx_val
>=
dst_rank_expert_begin
&&
idx_val
<
dst_rank_expert_end
)
?
(
idx_val
-
dst_rank_expert_begin
)
:
-
1
;
float
new_w
=
(
new_idx
!=
-
1
)
?
w_val
:
0.0
f
;
dst_gpu_buffer_topk_idx
[
t
]
=
new_idx
;
dst_gpu_buffer_topk_weights
[
t
]
=
new_w
;
}
}
// 在NVL缓冲区不足的情况下,提前停止
// 在NVL缓冲区不足的情况下,提前停止
...
...
csrc/kernels/utils.cuh
View file @
0395bf27
...
@@ -54,6 +54,7 @@
...
@@ -54,6 +54,7 @@
}
}
// HELPER FUNCTIONS
// HELPER FUNCTIONS
// #####################################################################################
// #####################################################################################
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
template
<
typename
T
>
template
<
typename
T
>
__device__
__forceinline__
T
shfl_xor
(
const
T
val
,
int
laneMask
,
int
width
=
kWarpSize
,
__device__
__forceinline__
T
shfl_xor
(
const
T
val
,
int
laneMask
,
int
width
=
kWarpSize
,
...
@@ -118,7 +119,6 @@ __device__ __forceinline__ void trap() {
...
@@ -118,7 +119,6 @@ __device__ __forceinline__ void trap() {
}
}
__device__
__forceinline__
void
memory_fence
()
{
__device__
__forceinline__
void
memory_fence
()
{
__threadfence_system
();
__threadfence_system
();
}
}
...
@@ -151,11 +151,13 @@ __device__ __forceinline__ int ld_relaxed_sys_global(const int *ptr) {
...
@@ -151,11 +151,13 @@ __device__ __forceinline__ int ld_relaxed_sys_global(const int *ptr) {
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
return
ret
;
return
ret
;
}
}
__device__
__forceinline__
int
ld_relaxed_sys_global
(
const
uint64_t
*
ptr
)
{
__device__
__forceinline__
int
ld_relaxed_sys_global
(
const
uint64_t
*
ptr
)
{
uint64_t
ret
;
uint64_t
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
return
ret
;
return
ret
;
}
}
__device__
__forceinline__
int
ld_relaxed_sys_global
(
const
int64_t
*
ptr
)
{
__device__
__forceinline__
int
ld_relaxed_sys_global
(
const
int64_t
*
ptr
)
{
int64_t
ret
;
int64_t
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_SYSTEM
);
...
@@ -180,7 +182,6 @@ __device__ __forceinline__ int64_t ld_acquire_sys_global(const int64_t *ptr) {
...
@@ -180,7 +182,6 @@ __device__ __forceinline__ int64_t ld_acquire_sys_global(const int64_t *ptr) {
return
ret
;
return
ret
;
}
}
__device__
__forceinline__
int
ld_acquire_global
(
const
int
*
ptr
)
{
__device__
__forceinline__
int
ld_acquire_global
(
const
int
*
ptr
)
{
int
ret
;
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_AGENT
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_AGENT
);
...
@@ -269,12 +270,22 @@ __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
...
@@ -269,12 +270,22 @@ __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
__device__
__forceinline__
void
st_na_relaxed
(
const
float
*
ptr
,
float
val
)
{
float
*
non_const_ptr
=
const_cast
<
float
*>
(
ptr
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
__device__
__forceinline__
void
st_na_relaxed
(
const
int64_t
*
ptr
,
int64_t
val
)
{
int64_t
*
non_const_ptr
=
const_cast
<
int64_t
*>
(
ptr
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
__device__
__forceinline__
void
st_na_relaxed
(
const
int4
*
ptr
,
int4
val
)
{
__device__
__forceinline__
void
st_na_relaxed
(
const
int4
*
ptr
,
int4
val
)
{
int4
*
non_const_ptr
=
const_cast
<
int4
*>
(
ptr
);
int4
*
non_const_ptr
=
const_cast
<
int4
*>
(
ptr
);
non_const_ptr
->
x
=
val
.
x
;
__hip_atomic_store
(
&
(
non_const_ptr
->
x
),
val
.
x
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
;
non_const_ptr
->
y
=
val
.
y
;
__hip_atomic_store
(
&
(
non_const_ptr
->
y
),
val
.
y
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
;
non_const_ptr
->
z
=
val
.
z
;
__hip_atomic_store
(
&
(
non_const_ptr
->
z
),
val
.
z
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
;
non_const_ptr
->
w
=
val
.
w
;
__hip_atomic_store
(
&
(
non_const_ptr
->
w
),
val
.
w
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
;
}
}
__device__
__forceinline__
void
st_na_release
(
const
int
*
ptr
,
int
val
)
{
__device__
__forceinline__
void
st_na_release
(
const
int
*
ptr
,
int
val
)
{
...
@@ -297,6 +308,14 @@ __device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
...
@@ -297,6 +308,14 @@ __device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
__device__
__forceinline__
void
st_na_release
(
const
int4
*
ptr
,
int4
val
)
{
int4
*
non_const_ptr
=
const_cast
<
int4
*>
(
ptr
);
__hip_atomic_store
(
&
(
non_const_ptr
->
x
),
val
.
x
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
&
(
non_const_ptr
->
y
),
val
.
y
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
&
(
non_const_ptr
->
z
),
val
.
z
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
&
(
non_const_ptr
->
w
),
val
.
w
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template
<
typename
dtype_t
>
template
<
typename
dtype_t
>
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment