Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1a24c8b6
"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "9bb7f101078fc9b8bee51c03f4eb4c56ad0528df"
Commit
1a24c8b6
authored
Apr 20, 2026
by
lishen
Browse files
normal-combine深度优化
parent
ab0afb04
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
42 deletions
+63
-42
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+63
-42
No files found.
csrc/kernels/internode.cu
View file @
1a24c8b6
...
@@ -1357,20 +1357,14 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
...
@@ -1357,20 +1357,14 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWarpSize
)
{
// Read buffers
// Read buffers
// TODO: maybe too many registers here
float
values
[
kDtypePerInt4
]
=
{
0
};
// 8 × 4B = 32B
int4
recv_value_int4
[
kMaxNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
recv_value_int4
[
j
]
=
ld_nc_global
(
get_addr_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
));
int4
recv_value
=
ld_nc_global
(
get_addr_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
));
auto
recv_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value
);
// Reduce all-to-all results
float
values
[
kDtypePerInt4
]
=
{
0
};
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
auto
recv_value_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value_int4
[
j
]);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kDtypePerInt4
;
++
k
)
for
(
int
k
=
0
;
k
<
kDtypePerInt4
;
++
k
)
values
[
k
]
+=
static_cast
<
float
>
(
recv_
value_
dtypes
[
k
]);
values
[
k
]
+=
static_cast
<
float
>
(
recv_dtypes
[
k
]);
}
}
// Cast back to `dtype_t` and write
// Cast back to `dtype_t` and write
...
@@ -1835,47 +1829,74 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1835,47 +1829,74 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
token_start_idx
,
token_end_idx
;
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// Iterate over all tokens and combine
// ==================== Token 级展开 x4 ====================
constexpr
int
kTokenUnroll
=
4
;
int
cached_channel_tail_idx
=
0
;
int
cached_channel_tail_idx
=
0
;
for
(
int64_t
token_idx
=
token_start_idx
+
target_rank
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumRDMAReceivers
)
{
// Read expected head
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
int
expected_head
=
-
1
;
if
(
lane_id
<
kNumRDMARanks
)
{
expected_head
=
ld_nc_global
(
combined_rdma_head
+
token_idx
*
kNumRDMARanks
+
lane_id
);
(
expected_head
<
0
)
?
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
-
expected_head
-
1
)
:
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
expected_head
);
}
// Wait lanes to be ready
for
(
int64_t
base
=
token_start_idx
+
target_rank
;
auto
start_time
=
wall_clock64
();
base
<
token_end_idx
;
while
(
cached_channel_tail_idx
<=
expected_head
)
{
base
+=
(
int64_t
)
kNumRDMAReceivers
*
kTokenUnroll
)
{
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
// ---- Phase 1: 批量预取所有 token 的 expected_head ----
int
cached_expected_head
[
kTokenUnroll
];
int
max_expected_head
=
-
1
;
#pragma unroll
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
cached_expected_head
[
u
]
=
-
1
;
if
(
tidx
<
token_end_idx
&&
lane_id
<
kNumRDMARanks
)
{
int
expected_head
=
ld_nc_global
(
combined_rdma_head
+
tidx
*
kNumRDMARanks
+
lane_id
);
cached_expected_head
[
u
]
=
expected_head
;
// 更新 rdma_receiver_rdma_head(coordinator 需要)
(
expected_head
<
0
)
?
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
-
expected_head
-
1
)
:
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
expected_head
);
if
(
expected_head
>
max_expected_head
)
max_expected_head
=
expected_head
;
}
}
// Timeout check
// ---- Phase 2: 一次等待,覆盖所有 token ----
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
if
(
max_expected_head
>=
0
)
{
printf
(
"DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d
\n
"
,
auto
start_time
=
wall_clock64
();
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
cached_channel_tail_idx
,
token_idx
,
expected_head
);
while
(
cached_channel_tail_idx
<=
max_expected_head
)
{
trap
();
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine RDMA receiver timeout (unroll x%d), "
"ch: %d, rdma: %d, nvl: %d, lane: %d, "
"tail: %d, wait: %d
\n
"
,
kTokenUnroll
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
cached_channel_tail_idx
,
max_expected_head
);
trap
();
}
}
}
}
}
syncwarp
();
syncwarp
();
// Combine current token
// ---- Phase 3: 批量处理所有就绪 token ----
auto
get_addr_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
#pragma unroll
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
expected_head
,
lane_id
,
if
(
tidx
<
token_end_idx
)
{
hidden_int4
,
num_topk
,
int
expected_head
=
cached_expected_head
[
u
];
combined_x
+
token_idx
*
hidden_int4
,
// Combine current token
combined_topk_weights
+
token_idx
*
num_topk
,
auto
get_addr_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
num_max_rdma_chunked_recv_tokens
,
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
get_addr_fn
,
recv_tw_fn
);
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
combined_x
+
tidx
*
hidden_int4
,
combined_topk_weights
+
tidx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
get_addr_fn
,
recv_tw_fn
);
}
}
}
}
// Retired
// Retired
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
{
rdma_receiver_retired
[
target_rank
]
=
true
;
rdma_receiver_retired
[
target_rank
]
=
true
;
}
}
}
else
if
(
warp_role
==
WarpRole
::
kNVLCoordinator
)
{
}
else
if
(
warp_role
==
WarpRole
::
kNVLCoordinator
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment