Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
42494864
Commit
42494864
authored
Apr 07, 2025
by
Chenggang Zhao
Browse files
Remove useless control metadata for low-latency combine
parent
2a0b3d7a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
9 deletions
+8
-9
csrc/config.hpp
csrc/config.hpp
+3
-3
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+4
-5
tests/test_low_latency.py
tests/test_low_latency.py
+1
-1
No files found.
csrc/config.hpp
View file @
42494864
...
...
@@ -122,7 +122,6 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
128
;
const
int
num_local_experts
=
num_experts
/
num_ranks
;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
...
...
@@ -130,9 +129,10 @@ struct LowLatencyLayout {
// - 2 symmetric odd/even signaling buffers
// Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
hidden
);
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
nv_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
size_t
num_bytes_per_combine_msg
=
sizeof
(
int4
)
+
hidden
*
sizeof
(
nv_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
nv_bfloat16
);
// Send buffer
size_t
dispatch_send_buffer_bytes
=
num_max_dispatch_tokens_per_rank
*
num_bytes_per_dispatch_msg
;
...
...
@@ -167,7 +167,7 @@ struct LowLatencyLayout {
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
num_bytes_per_combine_msg
};
}
...
...
csrc/kernels/internode_ll.cu
View file @
42494864
...
...
@@ -369,8 +369,7 @@ combine(void* combined_x,
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
nv_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
nv_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// Sending phase
...
...
@@ -409,12 +408,12 @@ combine(void* combined_x,
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
kNumWarpsPerGroup
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
)
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
...
...
@@ -473,7 +472,7 @@ combine(void* combined_x,
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
...
...
tests/test_low_latency.py
View file @
42494864
...
...
@@ -84,7 +84,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
assert
diff
<
1e-5
,
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment