Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
5a2e37fa
Unverified
Commit
5a2e37fa
authored
Jun 09, 2025
by
Chenggang Zhao
Committed by
GitHub
Jun 09, 2025
Browse files
Support statistics tensor for low-latency kernels (#196)
parent
0d1a855d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
3 deletions
+27
-3
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+7
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+5
-0
deep_ep/buffer.py
deep_ep/buffer.py
+9
-3
tests/test_low_latency.py
tests/test_low_latency.py
+4
-0
No files found.
csrc/deep_ep.cpp
View file @
5a2e37fa
...
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
topk_idx
.
size
(
0
)
and
x
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
topk_idx
.
size
(
0
)
and
x
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
if
(
cumulative_local_expert_recv_stats
.
has_value
())
{
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
dim
()
==
1
and
cumulative_local_expert_recv_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
size
(
0
)
==
num_experts
/
num_ranks
);
}
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
...
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
cumulative_local_expert_recv_stats
.
has_value
()
?
cumulative_local_expert_recv_stats
->
data_ptr
<
int
>
()
:
nullptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
...
...
csrc/deep_ep.hpp
View file @
5a2e37fa
...
@@ -142,6 +142,7 @@ public:
...
@@ -142,6 +142,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
...
...
csrc/kernels/api.cuh
View file @
5a2e37fa
...
@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
...
@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
...
...
csrc/kernels/internode_ll.cu
View file @
5a2e37fa
...
@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
...
@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
...
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
if
(
cumulative_local_expert_recv_stats
!=
nullptr
)
atomicAdd
(
cumulative_local_expert_recv_stats
+
local_expert_idx
,
num_recv_tokens
);
}
}
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
warp_group_id
+
2
),
"r"
(
kNumWarpsPerGroup
*
32
));
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
warp_group_id
+
2
),
"r"
(
kNumWarpsPerGroup
*
32
));
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
...
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
...
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
...
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
packed_recv_count, \
cumulative_local_expert_recv_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
...
...
deep_ep/buffer.py
View file @
5a2e37fa
...
@@ -473,6 +473,7 @@ class Buffer:
...
@@ -473,6 +473,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
...
@@ -481,7 +482,7 @@ class Buffer:
...
@@ -481,7 +482,7 @@ class Buffer:
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
low-latency kernels' result tensor
s
at a single moment.
Arguments:
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...
@@ -490,6 +491,9 @@ class Buffer:
...
@@ -490,6 +491,9 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported.
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
...
@@ -508,19 +512,21 @@ class Buffer:
...
@@ -508,19 +512,21 @@ class Buffer:
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, not all tokens are valid in `recv_x`.
expert receive
s
. As mentioned before, not all tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
cumulative_local_expert_recv_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
use_fp8
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
,
cumulative_local_expert_recv_stats
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
...
tests/test_low_latency.py
View file @
5a2e37fa
...
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
for
dispatch_use_fp8
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
num_times
+=
1
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
for
i
in
range
((
num_times
%
2
)
+
1
):
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
...
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check expert indices
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
num_valid_tokens
=
recv_count
.
item
()
assert
cumulative_local_expert_recv_stats
[
i
].
item
()
==
num_valid_tokens
,
f
'
{
cumulative_local_expert_recv_stats
[
i
].
item
()
}
!=
{
num_valid_tokens
}
'
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
...
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
if
zero_copy
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment