Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
dbf9fd61
Commit
dbf9fd61
authored
Feb 05, 2026
by
lishen
Browse files
Merge branch 'quant_main' into 'main'
量化scales传输size优化 See merge request dcutoolkit/deeplearing/DeepEP!20
parents
d0fcf024
e57e9270
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
37 additions
and
35 deletions
+37
-35
csrc/config.hpp
csrc/config.hpp
+4
-4
csrc/deep_ep.cu
csrc/deep_ep.cu
+3
-3
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+6
-6
deep_ep/buffer.py
deep_ep/buffer.py
+6
-4
tests/test_internode.py
tests/test_internode.py
+3
-3
tests/test_intranode.py
tests/test_intranode.py
+4
-4
tests/test_low_latency.py
tests/test_low_latency.py
+3
-3
tests/utils.py
tests/utils.py
+7
-7
No files found.
csrc/config.hpp
View file @
dbf9fd61
...
@@ -135,8 +135,8 @@ struct LowLatencyLayout {
...
@@ -135,8 +135,8 @@ struct LowLatencyLayout {
}
}
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
const
int
num_scales
=
hidden
/
QUANTIZATION_GROUPSIZE
;
const
int
num_scales
=
quant_group_size
==
0
?
4
:
hidden
/
QUANTIZATION_GROUPSIZE
;
// 应该是1,但是代码中为了满足int4对齐
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
@@ -205,9 +205,9 @@ struct LowLatencyLayout {
...
@@ -205,9 +205,9 @@ struct LowLatencyLayout {
};
};
inline
size_t
get_low_latency_rdma_size_hint
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
inline
size_t
get_low_latency_rdma_size_hint
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
auto
num_bytes
=
auto
num_bytes
=
LowLatencyLayout
(
nullptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
)
LowLatencyLayout
(
nullptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
)
.
total_bytes
;
.
total_bytes
;
return
((
num_bytes
+
NUM_BUFFER_ALIGNMENT_BYTES
)
/
NUM_BUFFER_ALIGNMENT_BYTES
)
*
return
((
num_bytes
+
NUM_BUFFER_ALIGNMENT_BYTES
)
/
NUM_BUFFER_ALIGNMENT_BYTES
)
*
NUM_BUFFER_ALIGNMENT_BYTES
;
NUM_BUFFER_ALIGNMENT_BYTES
;
...
...
csrc/deep_ep.cu
View file @
dbf9fd61
...
@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
...
@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
#endif
#endif
}
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
,
int
quant_group_size
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
);
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
...
@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
num_local_experts
=
num_experts
/
num_ranks
;
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
...
...
csrc/deep_ep.hpp
View file @
dbf9fd61
...
@@ -172,7 +172,7 @@ public:
...
@@ -172,7 +172,7 @@ public:
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
int
num_experts
,
int
quant_group_size
=
0
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
...
...
csrc/kernels/internode_ll.cu
View file @
dbf9fd61
...
@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Message package: hidden data, FP8 scales, index at source
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUseQuant8Bit
,
int2
,
int4
>::
type
;
using
vec_t
=
typename
std
::
conditional
<
kUseQuant8Bit
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseQuant8Bit
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseQuant8Bit
?
(
kHidden
+
(
kQuantGroupSize
==
0
?
4
:
kNumScales
)
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
// Expert counts
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
__shared__
int
shared_num_tokens_sent_per_expert
[
kMaxNumWarps
];
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
// Sending phase
// Sending phase
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
...
@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
auto
num_threads
=
num_warps
*
kWarpSize
;
constexpr
int
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
constexpr
int
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
...
@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
// This SM should be responsible for some destination experts, read `topk_idx` for them
int
expert_count
[
k
NumMaxWarpGrou
ps
]
=
{
0
};
int
expert_count
[
k
MaxNumWar
ps
]
=
{
0
};
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
...
@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
(
kQuantGroupSize
==
0
?
1
:
num_aligned_scales
);
(
kQuantGroupSize
==
0
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
k
NumMaxWarpGrou
ps
],
shared_recv_token_begin_idx
[
k
NumMaxWarpGrou
ps
];
__shared__
int
shared_num_recv_tokens
[
k
MaxNumWar
ps
],
shared_recv_token_begin_idx
[
k
MaxNumWar
ps
];
// Wait tokens to arrive
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
// NOTES: using sub-warp 1 to overlap with sub-warp 0
...
...
deep_ep/buffer.py
View file @
dbf9fd61
...
@@ -212,7 +212,7 @@ class Buffer:
...
@@ -212,7 +212,7 @@ class Buffer:
@
staticmethod
@
staticmethod
def
get_low_latency_rdma_size_hint
(
def
get_low_latency_rdma_size_hint
(
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_ranks
:
int
,
num_experts
:
int
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_ranks
:
int
,
num_experts
:
int
,
quant_group_size
:
int
=
0
)
->
int
:
)
->
int
:
"""
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...
@@ -222,12 +222,13 @@ class Buffer:
...
@@ -222,12 +222,13 @@ class Buffer:
hidden: the hidden dimension of each token.
hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks.
num_ranks: the number of EP group ranks.
num_experts: the number of all experts.
num_experts: the number of all experts.
quant_group_size: the group size if use quant.
Returns:
Returns:
size: the RDMA buffer size recommended.
size: the RDMA buffer size recommended.
"""
"""
return
deep_ep_cpp
.
get_low_latency_rdma_size_hint
(
return
deep_ep_cpp
.
get_low_latency_rdma_size_hint
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
,
quant_group_size
)
)
def
get_comm_stream
(
self
)
->
torch
.
Stream
:
def
get_comm_stream
(
self
)
->
torch
.
Stream
:
...
@@ -823,7 +824,7 @@ class Buffer:
...
@@ -823,7 +824,7 @@ class Buffer:
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
def
clean_low_latency_buffer
(
def
clean_low_latency_buffer
(
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
,
quant_group_size
:
int
=
0
)
->
None
:
)
->
None
:
"""
"""
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
...
@@ -835,8 +836,9 @@ class Buffer:
...
@@ -835,8 +836,9 @@ class Buffer:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
hidden: the hidden dimension of each token.
num_experts: the number of all experts.
num_experts: the number of all experts.
quant_group_size: the group size if use quant.
"""
"""
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
,
quant_group_size
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
...
...
tests/test_internode.py
View file @
dbf9fd61
...
@@ -6,7 +6,7 @@ import torch.distributed as dist
...
@@ -6,7 +6,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
create_grouped_scores
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_back
,
hash_tensor
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
create_grouped_scores
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_
pg_
back
,
hash_tensor
# Test compatibility with low latency functions
# Test compatibility with low latency functions
import
test_low_latency
import
test_low_latency
...
@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
hash_value
+=
hash_tensor
(
recv_x
[
0
])
hash_value
+=
hash_tensor
(
recv_x
[
0
])
hash_value
+=
hash_tensor
(
recv_x
[
1
])
hash_value
+=
hash_tensor
(
recv_x
[
1
])
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
# Checks
# Checks
recv_gbl_rank_prefix_sum
=
handle
[
-
4
]
recv_gbl_rank_prefix_sum
=
handle
[
-
4
]
...
@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
if
not
is_rand
:
if
not
is_rand
:
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
...
...
tests/test_intranode.py
View file @
dbf9fd61
...
@@ -5,7 +5,7 @@ import torch.distributed as dist
...
@@ -5,7 +5,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
calc_diff
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
calc_diff
,
inplace_unique
,
per_token_cast_to_fp8
,
per_token_cast_
pg_
back
# Test compatibility with low latency functions
# Test compatibility with low latency functions
import
test_low_latency
import
test_low_latency
...
@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
recv_num_tokens_per_expert_list
,
handle
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
recv_num_tokens_per_expert_list
,
handle
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
# Checks
# Checks
rank_prefix_matrix
=
handle
[
0
]
rank_prefix_matrix
=
handle
[
0
]
...
@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'num_worst_tokens'
:
num_worst_tokens
})
dispatch_args
.
update
({
'num_worst_tokens'
:
num_worst_tokens
})
recv_worst_x
,
recv_worst_topk_idx
,
recv_worst_topk_weights
,
empty_list
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_worst_x
,
recv_worst_topk_idx
,
recv_worst_topk_weights
,
empty_list
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_worst_x
=
per_token_cast_back
(
*
recv_worst_x
)
if
isinstance
(
recv_worst_x
,
tuple
)
else
recv_worst_x
recv_worst_x
=
per_token_cast_
pg_
back
(
*
recv_worst_x
)
if
isinstance
(
recv_worst_x
,
tuple
)
else
recv_worst_x
assert
len
(
empty_list
)
==
0
assert
len
(
empty_list
)
==
0
assert
num_worst_tokens
==
recv_worst_x
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_x
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_idx
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_idx
.
size
(
0
)
...
@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
dispatch_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
recv_x
=
per_token_cast_
pg_
back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
if
current_x
is
not
x_pure_rand
:
if
current_x
is
not
x_pure_rand
:
check_data
(
recv_x
,
rank_prefix_matrix
)
check_data
(
recv_x
,
rank_prefix_matrix
)
...
...
tests/test_low_latency.py
View file @
dbf9fd61
...
@@ -4,7 +4,7 @@ import torch.distributed as dist
...
@@ -4,7 +4,7 @@ import torch.distributed as dist
from
functools
import
partial
from
functools
import
partial
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_
pg_
back
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
...
@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# return
# return
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
simulated_gemm_x
=
per_token_cast_
pg_
back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
...
@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_x
=
per_token_cast_
pg_
back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
# Check expert indices
...
...
tests/utils.py
View file @
dbf9fd61
...
@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
...
@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_pc_back
(
x
_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
def
per_token_cast_pc_back
(
x
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
_int8
.
numel
()
==
0
:
if
x
.
numel
()
==
0
:
return
x
_int8
.
to
(
torch
.
bfloat16
)
return
x
.
to
(
torch
.
bfloat16
)
assert
x
_int8
.
dim
()
==
2
assert
x
.
dim
()
==
2
m
,
n
=
x
_int8
.
shape
m
,
n
=
x
.
shape
aligned_n
=
align_up
(
n
,
128
)
aligned_n
=
align_up
(
n
,
128
)
x_
int8_
padded
=
torch
.
nn
.
functional
.
pad
(
x
_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_padded
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_
int8_
padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_fp32_padded
=
x_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment