Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
830124e1
Commit
830124e1
authored
Feb 04, 2026
by
lishen
Browse files
量化scales传输size优化
parent
d0fcf024
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
18 deletions
+20
-18
csrc/config.hpp
csrc/config.hpp
+4
-4
csrc/deep_ep.cu
csrc/deep_ep.cu
+3
-3
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+6
-6
deep_ep/buffer.py
deep_ep/buffer.py
+6
-4
No files found.
csrc/config.hpp
View file @
830124e1
...
@@ -135,8 +135,8 @@ struct LowLatencyLayout {
...
@@ -135,8 +135,8 @@ struct LowLatencyLayout {
}
}
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
const
int
num_scales
=
hidden
/
QUANTIZATION_GROUPSIZE
;
const
int
num_scales
=
quant_group_size
==
0
?
4
:
hidden
/
QUANTIZATION_GROUPSIZE
;
// Should be 1, but 4 is used here to satisfy int4 alignment
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
@@ -205,9 +205,9 @@ struct LowLatencyLayout {
...
@@ -205,9 +205,9 @@ struct LowLatencyLayout {
};
};
inline
size_t
get_low_latency_rdma_size_hint
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
inline
size_t
get_low_latency_rdma_size_hint
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
auto
num_bytes
=
auto
num_bytes
=
LowLatencyLayout
(
nullptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
)
LowLatencyLayout
(
nullptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
)
.
total_bytes
;
.
total_bytes
;
return
((
num_bytes
+
NUM_BUFFER_ALIGNMENT_BYTES
)
/
NUM_BUFFER_ALIGNMENT_BYTES
)
*
return
((
num_bytes
+
NUM_BUFFER_ALIGNMENT_BYTES
)
/
NUM_BUFFER_ALIGNMENT_BYTES
)
*
NUM_BUFFER_ALIGNMENT_BYTES
;
NUM_BUFFER_ALIGNMENT_BYTES
;
...
...
csrc/deep_ep.cu
View file @
830124e1
...
@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
...
@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
#endif
#endif
}
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
,
int
quant_group_size
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
);
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
...
@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
num_local_experts
=
num_experts
/
num_ranks
;
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
...
...
csrc/deep_ep.hpp
View file @
830124e1
...
@@ -172,7 +172,7 @@ public:
...
@@ -172,7 +172,7 @@ public:
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
int
num_experts
,
int
quant_group_size
=
0
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
...
...
csrc/kernels/internode_ll.cu
View file @
830124e1
...
@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Message package: hidden data, FP8 scales, index at source
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUseQuant8Bit
,
int2
,
int4
>::
type
;
using
vec_t
=
typename
std
::
conditional
<
kUseQuant8Bit
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseQuant8Bit
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseQuant8Bit
?
(
kHidden
+
(
kQuantGroupSize
==
0
?
4
:
kNumScales
)
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
// Expert counts
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
__shared__
int
shared_num_tokens_sent_per_expert
[
kMaxNumWarps
];
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
// Sending phase
// Sending phase
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
...
@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
auto
num_threads
=
num_warps
*
kWarpSize
;
constexpr
int
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
constexpr
int
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
...
@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
// This SM should be responsible for some destination experts, read `topk_idx` for them
int
expert_count
[
kNumMaxWarpGroups
]
=
{
0
};
int
expert_count
[
kMaxNumWarps
]
=
{
0
};
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
...
@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
(
kQuantGroupSize
==
0
?
1
:
num_aligned_scales
);
(
kQuantGroupSize
==
0
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kMaxNumWarps
],
shared_recv_token_begin_idx
[
kMaxNumWarps
];
// Wait tokens to arrive
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
// NOTES: using sub-warp 1 to overlap with sub-warp 0
...
...
deep_ep/buffer.py
View file @
830124e1
...
@@ -212,7 +212,7 @@ class Buffer:
...
@@ -212,7 +212,7 @@ class Buffer:
@
staticmethod
@
staticmethod
def
get_low_latency_rdma_size_hint
(
def
get_low_latency_rdma_size_hint
(
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_ranks
:
int
,
num_experts
:
int
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_ranks
:
int
,
num_experts
:
int
,
quant_group_size
:
int
=
0
)
->
int
:
)
->
int
:
"""
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...
@@ -222,12 +222,13 @@ class Buffer:
...
@@ -222,12 +222,13 @@ class Buffer:
hidden: the hidden dimension of each token.
hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks.
num_ranks: the number of EP group ranks.
num_experts: the number of all experts.
num_experts: the number of all experts.
quant_group_size: the group size used for quantization scales (defaults to 0).
Returns:
Returns:
size: the RDMA buffer size recommended.
size: the RDMA buffer size recommended.
"""
"""
return
deep_ep_cpp
.
get_low_latency_rdma_size_hint
(
return
deep_ep_cpp
.
get_low_latency_rdma_size_hint
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
)
)
def
get_comm_stream
(
self
)
->
torch
.
Stream
:
def
get_comm_stream
(
self
)
->
torch
.
Stream
:
...
@@ -823,7 +824,7 @@ class Buffer:
...
@@ -823,7 +824,7 @@ class Buffer:
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
def
clean_low_latency_buffer
(
def
clean_low_latency_buffer
(
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
,
quant_group_size
:
int
=
0
)
->
None
:
)
->
None
:
"""
"""
Because low-latency kernels require part of the buffer to be zero-initialized, it is vital to clean the buffer
Because low-latency kernels require part of the buffer to be zero-initialized, it is vital to clean the buffer
...
@@ -835,8 +836,9 @@ class Buffer:
...
@@ -835,8 +836,9 @@ class Buffer:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
hidden: the hidden dimension of each token.
num_experts: the number of all experts.
num_experts: the number of all experts.
quant_group_size: the group size used for quantization scales (defaults to 0).
"""
"""
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
,
quant_group_size
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment