Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
92405ddf
Commit
92405ddf
authored
May 23, 2025
by
Chenggang Zhao
Browse files
Code cleanup and bug fixed
parent
68ae8b3d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
45 additions
and
48 deletions
+45
-48
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+2
-1
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+14
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+21
-43
deep_ep/buffer.py
deep_ep/buffer.py
+7
-3
No files found.
csrc/deep_ep.cpp
View file @
92405ddf
...
@@ -1190,8 +1190,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1190,8 +1190,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
}
}
torch
::
Tensor
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
{
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
dtype
=
torch
::
kBFloat16
;
auto
dtype
=
torch
::
kBFloat16
;
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
...
...
csrc/deep_ep.hpp
View file @
92405ddf
...
@@ -147,7 +147,7 @@ public:
...
@@ -147,7 +147,7 @@ public:
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
torch
::
Tensor
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/ibgda_device.cuh
View file @
92405ddf
...
@@ -432,4 +432,18 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
...
@@ -432,4 +432,18 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
}
}
}
}
__device__
__forceinline__
uint64_t
nvshmemi_get_p2p_ptr
(
const
uint64_t
&
ptr
,
const
int
&
rank
,
const
int
&
dst_rank
)
{
// Local rank, no need for mapping
if
(
rank
==
dst_rank
)
return
ptr
;
auto
peer_base
=
__ldg
(
reinterpret_cast
<
uint64_t
*>
(
nvshmemi_device_state_d
.
peer_heap_base_p2p
)
+
dst_rank
);
// RDMA connected
if
(
peer_base
==
0
)
return
0
;
// NVLink P2P is enabled
return
peer_base
+
(
ptr
-
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
));
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode_ll.cu
View file @
92405ddf
...
@@ -149,20 +149,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -149,20 +149,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
const
auto
dst_p2p_ptr
=
nvshmemi_get_p2p_ptr
(
dst_ptr
,
rank
,
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
dst_p2p_ptr
==
0
)
{
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
{
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
dst_rank
,
dst_expert_local_idx
,
lane_id
,
slot_idx
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
dst_rank
,
dst_expert_local_idx
,
lane_id
,
slot_idx
);
}
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_
p2p_
ptr
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
}
...
@@ -222,16 +215,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -222,16 +215,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
auto
dst_p2p_ptr
=
nvshmemi_get_p2p_ptr
(
dst_ptr
,
rank
,
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
if
(
dst_p2p_ptr
==
0
)
{
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
nvshmemi_ibgda_amo_nonfetch_add
(
reinterpret_cast
<
int
*>
(
dst_ptr
),
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
}
else
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
st_release_sys_global
(
reinterpret_cast
<
int
*>
(
dst_p2p_ptr
),
-
num_tokens_sent
-
1
);
}
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
}
// Clean workspace for next use
// Clean workspace for next use
...
@@ -428,22 +417,15 @@ combine(void* combined_x,
...
@@ -428,22 +417,15 @@ combine(void* combined_x,
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
const
auto
dst_p2p_ptr
=
nvshmemi_get_p2p_ptr
(
dst_ptr
,
rank
,
dst_rank
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
if
(
dst_p2p_ptr
==
0
)
{
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
if
(
not
zero_copy
)
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
}
}
else
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
}
}
}
...
@@ -452,16 +434,12 @@ combine(void* combined_x,
...
@@ -452,16 +434,12 @@ combine(void* combined_x,
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
warp_group_id
+
1
),
"r"
(
kNumWarpsPerGroup
*
32
));
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
warp_group_id
+
1
),
"r"
(
kNumWarpsPerGroup
*
32
));
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_flag
+
global_expert_idx
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
auto
dst_p2p_ptr
=
nvshmemi_get_p2p_ptr
(
dst_ptr
,
rank
,
dst_rank
);
if
(
peer_base_addr
)
{
if
(
dst_p2p_ptr
==
0
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
nvshmemi_ibgda_amo_nonfetch_add
(
reinterpret_cast
<
int
*>
(
dst_ptr
),
1
,
dst_rank
,
local_expert_idx
);
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
);
}
}
else
{
}
else
{
st_
na_
release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
st_release
_sys_global
(
reinterpret_cast
<
int
*>
(
dst_p2p_ptr
)
,
1
);
}
}
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
}
}
...
@@ -473,7 +451,7 @@ combine(void* combined_x,
...
@@ -473,7 +451,7 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
// Wait all ranks to arrive
and notify PCIe usage
// Wait all ranks to arrive
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
...
...
deep_ep/buffer.py
View file @
92405ddf
...
@@ -31,7 +31,8 @@ class Buffer:
...
@@ -31,7 +31,8 @@ class Buffer:
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
12
)
->
None
:
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
12
,
allow_nvlink_for_low_latency_mode
:
bool
=
False
)
->
None
:
"""
"""
Initialize the communication buffer.
Initialize the communication buffer.
...
@@ -42,6 +43,10 @@ class Buffer:
...
@@ -42,6 +43,10 @@ class Buffer:
low_latency_mode: whether to enable low-latency mode.
low_latency_mode: whether to enable low-latency mode.
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
to the number of local experts.
to the number of local experts.
allow_nvlink_for_low_latency_mode: whether allow NVLink traffic for low-latency mode, you should notice
this is somehow incompatible with the hook-based overlapping.
Warning: PCIe connections may lead to errors due to memory ordering issues,
please make sure all connections are via NVLink.
"""
"""
# Initialize the CPP runtime
# Initialize the CPP runtime
...
@@ -68,8 +73,7 @@ class Buffer:
...
@@ -68,8 +73,7 @@ class Buffer:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
# Enable IBGDA
# Enable IBGDA
assert
num_qps_per_rank
>
0
assert
num_qps_per_rank
>
0
if
not
os
.
getenv
(
"NVSHMEM_DISABLE_P2P"
):
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'0'
if
allow_nvlink_for_low_latency_mode
else
'1'
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment