Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
8a0688f3
Commit
8a0688f3
authored
Dec 15, 2025
by
lishen
Browse files
简化FORCE_NVSHMEM_API宏定义的数量
parent
6b49c021
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
15 deletions
+14
-15
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+14
-15
No files found.
csrc/kernels/internode_ll.cu
View file @
8a0688f3
...
@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
{
}
else
#endif
#endif
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#else
...
@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#endif
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
num_bytes_per_msg
,
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
}
#endif
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
}
else
#endif
#endif
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
#endif
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
}
#endif
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
...
@@ -648,8 +647,9 @@ combine(void* combined_x,
...
@@ -648,8 +647,9 @@ combine(void* combined_x,
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
}
else
#endif
#endif
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#else
...
@@ -657,9 +657,7 @@ combine(void* combined_x,
...
@@ -657,9 +657,7 @@ combine(void* combined_x,
#endif
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
}
#endif
}
}
}
}
...
@@ -677,19 +675,19 @@ combine(void* combined_x,
...
@@ -677,19 +675,19 @@ combine(void* combined_x,
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
}
else
#endif
#endif
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
#endif
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#if defined(FORCE_NVSHMEM_API)
}
}
#endif
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
}
...
@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV:
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
// Read from sources
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
// Reduce
// Reduce
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment