Unverified Commit e31946f8 authored by Xiaozhu Meng, committed by GitHub

[flashinfer] fix FI all2all with FI cutlass moe (#28166)


Signed-off-by: Xiaozhu <mxz297@gmail.com>
parent bde50393
@@ -233,12 +233,13 @@ def flashinfer_alltoall_dispatch(
     max_num_token = (
         max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
     )
+    orig_topk_weights_dtype = topk_weights.dtype
     alltoall_info, topk_ids, topk_weights, _ = (
         MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
             topk_ids,
             topk_weights,
             None,
-            all2all_manager.prepare_workspace,
+            all2all_manager.prepare_workspace_tensor,
             max_num_token,
             ep_rank,
             ep_size,
@@ -247,6 +248,7 @@ def flashinfer_alltoall_dispatch(
             top_k,
         )
     )
+    topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
     x, x_sf = moe_kernel_quantize_input(
         x,
......
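
Note on the dtype restore: the change records `topk_weights.dtype` before the prepare call and afterwards applies `topk_weights.view(dtype=...)`. In PyTorch, `Tensor.view(dtype)` reinterprets the existing bytes instead of converting values. The restore therefore only makes sense if the prepare call returns the same bytes under a different dtype of equal width. The commit does not state this, so treat it as an assumption. The sketch below imitates the pattern with the standard library only: `memoryview.cast` stands in for `Tensor.view`, and `prepare_stub` is a hypothetical placeholder, not the FlashInfer API.

```python
from array import array


def prepare_stub(weights: memoryview) -> memoryview:
    """Hypothetical stand-in for the prepare call (an assumption, not FlashInfer).

    Returns the same bytes exposed as 32-bit integers instead of 32-bit floats.
    """
    # memoryview.cast must pass through a byte format between two typed formats.
    return weights.cast("B").cast("i")


# Four float32 routing weights, all exactly representable in binary.
weights = memoryview(array("f", [0.25, 0.5, 0.125, 0.125]))

# Remember the element format first (analogous to topk_weights.dtype).
orig_format = weights.format

# The stub hands the buffer back as integers: same bits, different type.
returned = prepare_stub(weights)

# Reinterpret the bits under the remembered format
# (analogous to topk_weights.view(dtype=orig_topk_weights_dtype)).
restored = returned.cast("B").cast(orig_format)

print(returned.format, restored.format)
print(restored.tolist())
```

No data is copied at any point. `restored` shares memory with the original buffer, which is why the restore is cheap enough to sit on the dispatch path.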