Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
4b8d4b15
Commit
4b8d4b15
authored
Jun 05, 2026
by
lijian6
Browse files
fix cached_notify err when sm greater than 32.
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
e45581db
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
51 deletions
+48
-51
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+48
-51
No files found.
csrc/kernels/internode.cu
View file @
4b8d4b15
...
...
@@ -1213,63 +1213,60 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
}
else
if
(
sm_id
==
1
)
{
}
else
{
if
(
is_cached_dispatch
)
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
// Iterate in reverse order
if
(
lane_id
<
num_rdma_ranks
and
warp_id
<
num_channels
)
{
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
warp_id
,
token_start_idx
,
token_end_idx
);
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
auto
current_head
=
__ldg
(
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
);
if
(
current_head
<
0
)
{
combined_rdma_head
[
token_idx
*
num_rdma_ranks
+
lane_id
]
=
-
last_head
-
1
;
}
else
{
last_head
=
current_head
;
constexpr
int
num_clean_sms
=
1
;
const
int
logical_block
=
sm_id
-
num_clean_sms
;
const
int
total_blocks
=
gridDim
.
x
-
num_clean_sms
;
if
(
logical_block
<
0
)
return
;
if
(
combined_rdma_head
!=
nullptr
)
{
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
for
(
int
chan
=
logical_block
;
chan
<
num_channels
;
chan
+=
total_blocks
)
{
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
chan
,
token_start_idx
,
token_end_idx
);
for
(
int
token_idx
=
token_end_idx
-
1
-
warp_id
;
token_idx
>=
token_start_idx
;
token_idx
-=
num_warps
)
{
int
last_head
=
1
<<
25
;
if
(
lane_id
<
num_rdma_ranks
)
{
auto
ptr
=
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
;
int
current_head
=
__ldg
(
ptr
);
if
(
current_head
<
0
)
{
*
ptr
=
-
last_head
-
1
;
}
else
{
last_head
=
current_head
;
}
}
}
}
}
}
else
{
if
(
is_cached_dispatch
)
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
// Iterate in reverse order
int
token_start_idx
=
warp_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
-
1
];
int
token_end_idx
=
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
auto
current_head
=
__ldg
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
if
(
current_head
<
0
)
{
combined_nvl_head
[
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
]
=
-
last_head
-
1
;
}
else
{
last_head
=
current_head
;
if
(
combined_nvl_head
!=
nullptr
)
{
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
);
EP_DEVICE_ASSERT
(
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
for
(
int
chan
=
logical_block
;
chan
<
num_channels
;
chan
+=
total_blocks
)
{
for
(
int
dst_rdma_rank
=
0
;
dst_rdma_rank
<
num_rdma_ranks
;
++
dst_rdma_rank
)
{
int
token_start_idx
=
(
chan
==
0
)
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
chan
-
1
];
int
token_end_idx
=
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
chan
];
int
shift
=
(
dst_rdma_rank
==
0
)
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
;
token_end_idx
+=
shift
;
for
(
int
token_idx
=
token_end_idx
-
1
-
warp_id
;
token_idx
>=
token_start_idx
;
token_idx
-=
num_warps
)
{
int
last_head
=
1
<<
25
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
auto
ptr
=
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
;
int
current_head
=
__ldg
(
ptr
);
if
(
current_head
<
0
)
{
*
ptr
=
-
last_head
-
1
;
}
else
{
last_head
=
current_head
;
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment