"docs/vscode:/vscode.git/clone" did not exist on "1d34a19710c20bb27e1311326153c804903eb10f"
Unverified Commit 82eccae4 authored by fzyzcjy and committed by GitHub
Browse files

Let ep_scatter support arbitrary strides / ue8m0 format (#7309)

parent a8c10aee
......@@ -813,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = index_in_s < SCALE_HIDDEN_SIZE
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
recv_x_scale
+ token_id * recv_x_scale_stride0
+ index_in_s * recv_x_scale_stride1,
mask=mask_s,
)
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
......@@ -841,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
output_tensor_scale + dest_token_index * output_tensor_scale_stride0
)
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
tl.store(
output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
to_copy_s,
mask=mask_s,
)
# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
......@@ -856,6 +863,7 @@ def ep_scatter(
output_tensor_scale: torch.Tensor,
m_indices: torch.Tensor,
output_index: torch.Tensor,
scale_ue8m0: bool = False,
):
BLOCK_E = 128 # token num of per expert is aligned to 128
BLOCK_D = 128 # block size of quantization
......@@ -865,7 +873,15 @@ def ep_scatter(
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
scale_hidden_size = hidden_size // BLOCK_D
if scale_ue8m0:
# ue8m0 scales are packed here (4 scales per int32),
# hence the effective size of this dimension is divided by 4.
scale_hidden_size = ceil_div(scale_hidden_size, 4)
assert m_indices.shape[0] % BLOCK_E == 0
assert recv_x_scale.dtype == output_tensor_scale.dtype
assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
......@@ -904,8 +920,8 @@ def ep_scatter(
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment