Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
a8299ca7
Unverified
Commit
a8299ca7
authored
Jun 11, 2025
by
Chenggang Zhao
Committed by
GitHub
Jun 11, 2025
Browse files
Support CUDA graph for intranode normal kernels (#203)
parent
8da2d7b3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
89 additions
and
41 deletions
+89
-41
README.md
README.md
+1
-0
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+30
-20
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/intranode.cu
csrc/kernels/intranode.cu
+28
-15
deep_ep/buffer.py
deep_ep/buffer.py
+10
-4
tests/test_intranode.py
tests/test_intranode.py
+17
-0
No files found.
README.md
View file @
a8299ca7
...
...
@@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
allocate_on_comm_stream
=
previous_event
is
not
None
)
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# Unless you specify `num_worst_tokens`, but this flag is for intranode only
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
num_recv_tokens_per_expert_list
,
handle
,
event
=
\
_buffer
.
dispatch
(
x
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
...
...
csrc/deep_ep.cpp
View file @
a8299ca7
...
...
@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
const
std
::
optional
<
torch
::
Tensor
>&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
num_tokens_per_rank
,
const
torch
::
Tensor
&
is_token_in_rank
,
const
std
::
optional
<
torch
::
Tensor
>&
num_tokens_per_expert
,
int
cached_num_recv_tokens
,
const
std
::
optional
<
torch
::
Tensor
>&
cached_rank_prefix_matrix
,
const
std
::
optional
<
torch
::
Tensor
>&
cached_channel_prefix_matrix
,
int
expert_alignment
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
int
expert_alignment
,
int
num_worst_tokens
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
bool
cached_mode
=
cached_rank_prefix_matrix
.
has_value
();
// One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
...
...
@@ -412,25 +413,34 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
buffer_ptrs_gpu
,
barrier_signal_ptrs_gpu
,
rank
,
comm_stream
,
num_channels
);
// Synchronize total received tokens and tokens per expert
auto
start_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
while
(
true
)
{
// Read total count
num_recv_tokens
=
static_cast
<
int
>
(
*
moe_recv_counter
);
// Read per-expert count
bool
ready
=
(
num_recv_tokens
>=
0
);
for
(
int
i
=
0
;
i
<
num_local_experts
and
ready
;
++
i
)
ready
&=
moe_recv_expert_counter
[
i
]
>=
0
;
if
(
ready
)
break
;
// Timeout check
if
(
std
::
chrono
::
duration_cast
<
std
::
chrono
::
seconds
>
(
std
::
chrono
::
high_resolution_clock
::
now
()
-
start_time
).
count
()
>
NUM_CPU_TIMEOUT_SECS
)
throw
std
::
runtime_error
(
"DeepEP error: CPU recv timeout"
);
if
(
num_worst_tokens
>
0
)
{
// No CPU sync, just allocate the worst case
num_recv_tokens
=
num_worst_tokens
;
// Must be forward with top-k stuffs
EP_HOST_ASSERT
(
topk_idx
.
has_value
());
EP_HOST_ASSERT
(
topk_weights
.
has_value
());
}
else
{
// Synchronize total received tokens and tokens per expert
auto
start_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
while
(
true
)
{
// Read total count
num_recv_tokens
=
static_cast
<
int
>
(
*
moe_recv_counter
);
// Read per-expert count
bool
ready
=
(
num_recv_tokens
>=
0
);
for
(
int
i
=
0
;
i
<
num_local_experts
and
ready
;
++
i
)
ready
&=
moe_recv_expert_counter
[
i
]
>=
0
;
if
(
ready
)
break
;
// Timeout check
if
(
std
::
chrono
::
duration_cast
<
std
::
chrono
::
seconds
>
(
std
::
chrono
::
high_resolution_clock
::
now
()
-
start_time
).
count
()
>
NUM_CPU_TIMEOUT_SECS
)
throw
std
::
runtime_error
(
"DeepEP error: CPU recv timeout"
);
}
num_recv_tokens_per_expert_list
=
std
::
vector
<
int
>
(
moe_recv_expert_counter
,
moe_recv_expert_counter
+
num_local_experts
);
}
num_recv_tokens_per_expert_list
=
std
::
vector
<
int
>
(
moe_recv_expert_counter
,
moe_recv_expert_counter
+
num_local_experts
);
}
// Allocate new tensors
...
...
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head
.
data_ptr
<
int
>
(),
x
.
data_ptr
(),
x_scales_ptr
,
topk_idx_ptr
,
topk_weights_ptr
,
is_token_in_rank
.
data_ptr
<
bool
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
num_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
num_tokens
,
num_worst_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
comm_stream
,
config
.
num_sms
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
);
...
...
csrc/deep_ep.hpp
View file @
a8299ca7
...
...
@@ -108,7 +108,8 @@ public:
const
std
::
optional
<
torch
::
Tensor
>&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
num_tokens_per_rank
,
const
torch
::
Tensor
&
is_token_in_rank
,
const
std
::
optional
<
torch
::
Tensor
>&
num_tokens_per_expert
,
int
cached_num_recv_tokens
,
const
std
::
optional
<
torch
::
Tensor
>&
cached_rank_prefix_matrix
,
const
std
::
optional
<
torch
::
Tensor
>&
cached_channel_prefix_matrix
,
int
expert_alignment
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
int
expert_alignment
,
int
num_worst_tokens
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
intranode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
...
...
csrc/kernels/api.cuh
View file @
a8299ca7
...
...
@@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int
*
recv_src_idx
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
int
*
recv_channel_offset
,
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
);
...
...
csrc/kernels/intranode.cu
View file @
a8299ca7
...
...
@@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
int
*
per_rank_buffer
,
*
per_expert_buffer
;
if
(
thread_id
<
kNumRanks
)
{
per_rank_buffer
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
thread_id
]);
per_rank_buffer
=
static
_cast
<
int
*>
(
buffer_ptrs
[
thread_id
]);
per_expert_buffer
=
per_rank_buffer
+
kNumRanks
*
kNumRanks
;
}
...
...
@@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
auto
local_per_rank_buffer
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
auto
local_per_rank_buffer
=
static
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
if
(
thread_id
<
kNumRanks
)
{
#pragma unroll
for
(
int
i
=
1
;
i
<
kNumRanks
;
++
i
)
...
...
@@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
// Copy and clean
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
auto
ptr
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
auto
ptr
=
static
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
kNumRanks
*
kNumRanks
;
i
+=
num_threads
)
ptr
[
i
]
=
rank_prefix_matrix
[
i
];
...
...
@@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int
*
recv_src_idx
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
int
*
recv_channel_offset
,
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
void
**
buffer_ptrs
,
int
rank
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
),
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
...
@@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
auto
ptr
=
reinterpret_cast
<
void
*>
(
reinterpret
_cast
<
int8_t
*>
(
buffer_ptrs
[
is_sender
?
responsible_rank
:
rank
])
+
kNumRanks
*
kNumRanks
*
sizeof
(
int
));
auto
ptr
=
reinterpret_cast
<
void
*>
(
static
_cast
<
int8_t
*>
(
buffer_ptrs
[
is_sender
?
responsible_rank
:
rank
])
+
kNumRanks
*
kNumRanks
*
sizeof
(
int
));
int
target_rank
=
is_sender
?
rank
:
responsible_rank
;
auto
num_channels_total
=
num_channels
*
kNumRanks
;
auto
channel_rank_offset
=
responsible_channel
*
kNumRanks
+
target_rank
;
...
...
@@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int
chunk_token_idx
=
0
;
while
(
chunk_token_idx
<
num_max_send_tokens
and
token_idx
<
token_end_idx
)
{
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send
subsequent
data
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send
the following
data
if
(
lane_id
==
0
and
token_idx
%
num_send_warps_per_rank
==
send_warp_id_in_rank
)
send_head
[
token_idx
*
kNumRanks
+
responsible_rank
]
=
is_token_in_rank
[
token_idx
*
kNumRanks
+
responsible_rank
]
?
cached_channel_tail_idx
:
-
1
;
...
...
@@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
EP_DEVICE_ASSERT
(
recv_thread_id
>=
0
and
num_recv_warps
%
kNumRanks
==
0
);
// Calculate offset first
auto
rank_prefix_matrix
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
auto
rank_prefix_matrix
=
static
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
int
rank_offset
=
responsible_rank
>
0
?
rank_prefix_matrix
[(
responsible_rank
-
1
)
*
kNumRanks
+
rank
]
:
0
;
// Receive channel offset
...
...
@@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto
start_time
=
clock64
();
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
while
(
num_tokens_to_recv
>
0
)
{
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are
the
same
while
(
recv_thread_id_in_rank
==
0
)
{
cached_channel_tail_idx
=
ld_acquire_sys_global
(
channel_tail_idx
.
buffer
());;
...
...
@@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
if
(
lane_id
==
0
)
tma_store_wait
();
}
// Clean unused `recv_topk_idx` as -1
if
(
num_worst_tokens
>
0
)
{
auto
rank_prefix_matrix
=
static_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
const
auto
num_recv_tokens
=
rank_prefix_matrix
[(
kNumRanks
-
1
)
*
kNumRanks
+
rank
];
const
auto
clean_start
=
num_recv_tokens
*
num_topk
+
sm_id
*
kNumThreads
;
const
auto
clean_end
=
num_worst_tokens
*
num_topk
;
const
auto
clean_stride
=
num_sms
*
kNumThreads
;
#pragma unroll
for
(
int
i
=
clean_start
+
thread_id
;
i
<
clean_end
;
i
+=
clean_stride
)
recv_topk_idx
[
i
]
=
-
1
;
}
}
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int
*
recv_src_idx
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
int
*
recv_channel_offset
,
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
constexpr
int
kNumThreads
=
768
;
...
...
@@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
num_tokens,
num_worst_tokens,
hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
...
...
@@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
// Clean
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
auto
ptr
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
auto
ptr
=
static
_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_memset_int
;
i
+=
num_threads
)
ptr
[
i
]
=
0
;
...
...
@@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
EP_STATIC_ASSERT
(
num_send_warps
*
32
==
kNumThreads
,
"Invalid warp count"
);
// Calculate pointers by the specific layout
auto
ptr
=
reinterpret_cast
<
void
*>
(
reinterpret
_cast
<
int8_t
*>
(
buffer_ptrs
[
send_rank_id
]));
auto
ptr
=
reinterpret_cast
<
void
*>
(
static
_cast
<
int8_t
*>
(
buffer_ptrs
[
send_rank_id
]));
auto
num_channels_total
=
num_channels
*
kNumRanks
;
auto
channel_rank_offset
=
responsible_channel
*
kNumRanks
+
rank
;
...
...
@@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
asm
volatile
(
"bar.sync 0, %0;"
::
"r"
(
kNumThreads
));
if
(
thread_id
<
32
)
{
int
*
channel_head_idx_ptr
=
reinterpret
_cast
<
int
*>
(
buffer_ptrs
[
rank
])
+
responsible_channel
*
kNumRanks
+
lane_id
;
int
*
channel_head_idx_ptr
=
static
_cast
<
int
*>
(
buffer_ptrs
[
rank
])
+
responsible_channel
*
kNumRanks
+
lane_id
;
int
*
channel_tail_idx_ptr
=
channel_head_idx_ptr
+
num_channels
*
kNumRanks
;
// Queue head updater
...
...
@@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
auto
channel_rank_offset
=
responsible_channel
*
kNumRanks
+
i
;
auto
num_channels_total
=
num_channels
*
kNumRanks
;
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto
ptr
=
reinterpret_cast
<
void
*>
(
reinterpret
_cast
<
int8_t
*>
(
buffer_ptrs
[
rank
])
+
2
*
num_channels
*
kNumRanks
*
sizeof
(
int
));
auto
ptr
=
reinterpret_cast
<
void
*>
(
static
_cast
<
int8_t
*>
(
buffer_ptrs
[
rank
])
+
2
*
num_channels
*
kNumRanks
*
sizeof
(
int
));
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
channel_x_buffers
[
i
]
=
Buffer
<
int4
>
(
ptr
,
num_channels_total
*
num_recv_buffer_tokens
*
hidden_int4
,
channel_rank_offset
*
num_recv_buffer_tokens
*
hidden_int4
);
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
ptr
=
reinterpret_cast
<
void
*>
(
reinterpret
_cast
<
int8_t
*>
(
ptr
)
+
num_channels_total
*
num_recv_buffer_tokens
*
sizeof
(
int
));
ptr
=
reinterpret_cast
<
void
*>
(
static
_cast
<
int8_t
*>
(
ptr
)
+
num_channels_total
*
num_recv_buffer_tokens
*
sizeof
(
int
));
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
channel_topk_weights_buffers
[
i
]
=
Buffer
<
float
>
(
ptr
,
num_channels_total
*
num_recv_buffer_tokens
*
num_topk
,
channel_rank_offset
*
num_recv_buffer_tokens
*
num_topk
);
...
...
deep_ep/buffer.py
View file @
a8299ca7
...
...
@@ -249,7 +249,8 @@ class Buffer:
handle
:
Optional
[
Tuple
]
=
None
,
num_tokens_per_rank
:
Optional
[
torch
.
Tensor
]
=
None
,
num_tokens_per_rdma_rank
:
Optional
[
torch
.
Tensor
]
=
None
,
is_token_in_rank
:
Optional
[
torch
.
Tensor
]
=
None
,
num_tokens_per_expert
:
Optional
[
torch
.
Tensor
]
=
None
,
topk_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_alignment
:
int
=
1
,
topk_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_alignment
:
int
=
1
,
num_worst_tokens
:
int
=
0
,
config
:
Optional
[
Config
]
=
None
,
previous_event
:
Optional
[
EventOverlap
]
=
None
,
async_finish
:
bool
=
False
,
allocate_on_comm_stream
:
bool
=
False
)
->
\
...
...
@@ -276,6 +277,8 @@ class Buffer:
`-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable.
num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
...
...
@@ -296,6 +299,7 @@ class Buffer:
# Internode
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
:
assert
num_worst_tokens
==
0
,
'Internode dispatch does not support `num_worst_tokens > 0`'
return
self
.
internode_dispatch
(
x
,
handle
,
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
is_token_in_rank
,
num_tokens_per_expert
,
topk_idx
,
topk_weights
,
expert_alignment
,
config
,
previous_event
,
async_finish
,
allocate_on_comm_stream
)
...
...
@@ -308,14 +312,16 @@ class Buffer:
recv_x
,
recv_x_scales
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
event
=
self
.
runtime
.
intranode_dispatch
(
x
,
x_scales
,
None
,
None
,
None
,
is_token_in_rank
,
None
,
num_recv_tokens
,
rank_prefix_matrix
,
channel_prefix_matrix
,
expert_alignment
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
expert_alignment
,
num_worst_tokens
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
return
(
recv_x
,
recv_x_scales
)
if
x_scales
is
not
None
else
recv_x
,
None
,
None
,
None
,
None
,
EventOverlap
(
event
)
else
:
assert
num_tokens_per_rank
is
not
None
and
is_token_in_rank
is
not
None
and
num_tokens_per_expert
is
not
None
recv_x
,
recv_x_scales
,
recv_topk_idx
,
recv_topk_weights
,
num_recv_tokens_per_expert_list
,
rank_prefix_matrix
,
channel_prefix_matrix
,
recv_channel_prefix_matrix
,
recv_src_idx
,
send_head
,
event
=
\
self
.
runtime
.
intranode_dispatch
(
x
,
x_scales
,
topk_idx
,
topk_weights
,
num_tokens_per_rank
,
is_token_in_rank
,
num_tokens_per_expert
,
0
,
None
,
None
,
expert_alignment
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
num_tokens_per_rank
,
is_token_in_rank
,
num_tokens_per_expert
,
0
,
None
,
None
,
expert_alignment
,
num_worst_tokens
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
handle
=
(
rank_prefix_matrix
,
channel_prefix_matrix
,
recv_channel_prefix_matrix
,
recv_src_idx
,
is_token_in_rank
,
send_head
)
return
(
recv_x
,
recv_x_scales
)
if
x_scales
is
not
None
else
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
num_recv_tokens_per_expert_list
,
handle
,
EventOverlap
(
event
)
...
...
tests/test_intranode.py
View file @
a8299ca7
...
...
@@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
assert
gbl_num_tokens_per_expert
.
view
(
num_ranks
,
-
1
)[
rank
].
tolist
()
==
recv_num_tokens_per_expert_list
if
current_x
is
not
x_pure_rand
:
check_data
(
recv_x
,
rank_prefix_matrix
)
recv_topk_weights_clone
=
None
if
with_topk
:
# Check `topk_idx`
assert
(
recv_topk_idx
.
eq
(
-
1
)
|
((
recv_topk_idx
>=
0
)
&
(
recv_topk_idx
<
(
num_experts
//
num_ranks
)))).
sum
().
item
()
==
recv_topk_idx
.
numel
()
...
...
@@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
assert
recv_topk_idx
.
eq
(
i
).
sum
().
item
()
==
count
# Check `topk_weights`
recv_topk_weights_clone
=
recv_topk_weights
.
clone
()
if
current_x
is
not
x_pure_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
check_data
(
recv_topk_weights
,
rank_prefix_matrix
)
# Test `num_worst_tokens != 0`
if
with_topk
:
num_worst_tokens
=
num_tokens
*
num_ranks
dispatch_args
.
update
({
'num_worst_tokens'
:
num_worst_tokens
})
recv_worst_x
,
recv_worst_topk_idx
,
recv_worst_topk_weights
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
recv_worst_x
=
per_token_cast_back
(
*
recv_worst_x
)
if
isinstance
(
recv_worst_x
,
tuple
)
else
recv_worst_x
assert
num_worst_tokens
==
recv_worst_x
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_idx
.
size
(
0
)
assert
num_worst_tokens
==
recv_worst_topk_weights
.
size
(
0
)
assert
torch
.
equal
(
recv_x
,
recv_worst_x
[:
recv_x
.
size
(
0
)])
assert
torch
.
equal
(
recv_topk_idx
,
recv_worst_topk_idx
[:
recv_x
.
size
(
0
)])
assert
torch
.
equal
(
recv_topk_weights_clone
,
recv_worst_topk_weights
[:
recv_x
.
size
(
0
)])
assert
torch
.
all
(
recv_worst_topk_idx
[
recv_x
.
size
(
0
):]
==
-
1
).
item
()
# Test cached dispatch (must without top-k staffs)
if
not
with_topk
:
dispatch_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
,
'async_finish'
:
async_mode
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment