Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
68ae8b3d
Unverified
Commit
68ae8b3d
authored
May 23, 2025
by
cywork121
Committed by
GitHub
May 23, 2025
Browse files
Feature: LL nvlink p2p (#173)
parent
d5ca4495
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
5 deletions
+34
-5
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+32
-4
deep_ep/buffer.py
deep_ep/buffer.py
+2
-1
No files found.
csrc/kernels/internode_ll.cu
View file @
68ae8b3d
...
@@ -150,7 +150,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -150,7 +150,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
dst_rank
,
dst_expert_local_idx
,
lane_id
,
slot_idx
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
{
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
dst_rank
,
dst_expert_local_idx
,
lane_id
,
slot_idx
);
}
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -215,7 +223,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -215,7 +223,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
}
}
else
{
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
}
...
@@ -421,7 +435,15 @@ combine(void* combined_x,
...
@@ -421,7 +435,15 @@ combine(void* combined_x,
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
if
(
not
zero_copy
)
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
}
}
}
}
}
...
@@ -431,7 +453,13 @@ combine(void* combined_x,
...
@@ -431,7 +453,13 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
);
}
}
else
{
}
else
{
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
}
}
...
...
deep_ep/buffer.py
View file @
68ae8b3d
...
@@ -68,7 +68,8 @@ class Buffer:
...
@@ -68,7 +68,8 @@ class Buffer:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
# Enable IBGDA
# Enable IBGDA
assert
num_qps_per_rank
>
0
assert
num_qps_per_rank
>
0
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
if
not
os
.
getenv
(
"NVSHMEM_DISABLE_P2P"
):
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment