Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
6e838aa5
Commit
6e838aa5
authored
Dec 12, 2025
by
lijian6
Browse files
Feature: LL nvlink p2p for nvshmem.
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
26298255
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
6 deletions
+43
-6
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+43
-6
No files found.
csrc/kernels/internode_ll.cu
View file @
6e838aa5
...
@@ -265,6 +265,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -265,6 +265,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
{
#endif
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#else
...
@@ -272,11 +281,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -272,11 +281,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#endif
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
num_bytes_per_msg
,
dst_rank
);
// #if !defined(ROCM_DISABLE_CTX)
#if defined(FORCE_NVSHMEM_API)
// internode::shmem_ctx_quiet(ctx);
}
// #else
#endif
// internode::shmem_fence();
// #endif
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -339,12 +346,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -339,12 +346,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
#endif
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
#endif
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
#endif
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
...
@@ -625,7 +642,14 @@ combine(void* combined_x,
...
@@ -625,7 +642,14 @@ combine(void* combined_x,
if
(
not
zero_copy
)
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
#endif
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#else
...
@@ -633,6 +657,9 @@ combine(void* combined_x,
...
@@ -633,6 +657,9 @@ combine(void* combined_x,
#endif
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
#endif
}
}
}
}
...
@@ -647,12 +674,22 @@ combine(void* combined_x,
...
@@ -647,12 +674,22 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
#endif
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
#endif
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
#endif
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment