Unverified Commit 90f44b74 authored by SijiaYang, committed by GitHub

fix: w4afp8 accuracy problem and rebase (#8752)


Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: Jinwu <ayrnb@users.noreply.github.com>
parent 38907fe6
```diff
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
-    post_reorder_triton_kernel,
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
```
```diff
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )
     output = torch.empty_like(a)
-    post_reorder_triton_kernel[(m,)](
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-        topk_ids_,
+        local_topk_ids,
         topk_weights,
-        start_expert_id,
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,
```
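The accuracy-relevant change at this call site is the skip condition: as its argument list suggests, the old `post_reorder_triton_kernel` kept a contribution only when its expert id fell inside this rank's `[start_expert_id, end_expert_id]` range, while the replacement expects ids already remapped so that every invalid slot equals `num_experts`. In plain Python terms (illustrative only, not code from this commit):

```python
start_expert_id, end_expert_id, num_experts = 4, 7, 16

for expert_id in (3, 5, 16):
    # Old kernel: keep a contribution only for locally-owned expert ids.
    keep_old = start_expert_id <= expert_id <= end_expert_id
    # New kernel: ids are pre-remapped so every invalid slot equals
    # num_experts; everything else is kept.
    keep_new = expert_id != num_experts
    print(expert_id, keep_old, keep_new)
```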
```diff
@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
     )
 
 
+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
```
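For context, here is what the added kernel computes, expressed as a plain PyTorch reference: one program per source token, accumulating a weighted sum over that token's top-k expert outputs, gathered through `src2dst` and skipping slots whose id equals the `num_experts` sentinel. A minimal sketch (names mirror the kernel arguments; it is not code from this commit):

```python
import torch


def post_reorder_reference(down_output, src2dst, topk_ids, topk_weights,
                           num_experts, dst_start):
    """Reference semantics of post_reorder_triton_kernel_for_cutlass_moe."""
    m, topk = topk_ids.shape
    hidden_size = down_output.shape[1]
    output = torch.zeros(m, hidden_size, dtype=down_output.dtype,
                         device=down_output.device)
    for i in range(m):
        for j in range(topk):
            # num_experts is the sentinel for "not handled on this rank".
            if topk_ids[i, j] != num_experts:
                dst = src2dst[i, j] - dst_start
                weight = topk_weights[i, j].to(down_output.dtype)
                output[i] += weight * down_output[dst]
    return output
```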
```diff
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs
 
         # Fused gate_up_proj (column parallel)
```
```diff
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
```
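Tagging `extra_weight_attrs` with `FusedMoeWeightScaleSupported.GROUP.value` before the scale parameters are registered tells the shared FusedMoE weight loader to treat `w13_weight_scale` and `w2_weight_scale` as per-group scales when the checkpoint is loaded. A rough sketch of that dispatch, with a hypothetical simplified loader (the real one lives in `sglang.srt.layers.moe.fused_moe_triton`):

```python
from enum import Enum

import torch


class FusedMoeWeightScaleSupported(Enum):  # mirrors the SGLang enum values
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


def load_weight_scale(param, loaded_weight, expert_id, quant_method):
    # Hypothetical, simplified dispatch on the quant_method attribute.
    if quant_method == FusedMoeWeightScaleSupported.GROUP.value:
        # Per-group scales: copy the full 2-D scale block for this expert.
        param.data[expert_id].copy_(loaded_weight)
    elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
        # Per-tensor scale: a single scalar slot per expert.
        param.data[expert_id] = loaded_weight
```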
```diff
@@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-        hidden_states: torch.Tensor,
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
```
```diff
@@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        if layer.expert_map is not None:
-            "Translate info from expert_map to topk_ids"
-            local_topk_ids = torch.where(
-                layer.expert_map[topk_ids] != layer.num_experts,
-                layer.expert_map[topk_ids],
-                layer.num_experts,
-            )
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
 
-        return cutlass_w4a8_moe(
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-            hidden_states,
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
```
```diff
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
```
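The two behavioral changes in `apply()` are easy to reproduce in isolation; a standalone sketch with toy shapes (assumed values, not from this commit):

```python
import torch

num_experts = 8
topk_ids = torch.tensor([[0, 3], [5, -1]])  # -1 marks an unrouted slot

# Remap invalid slots to the num_experts sentinel that the post-reorder
# kernel skips (mirrors the new torch.where in apply()).
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
assert local_topk_ids.tolist() == [[0, 3], [5, 8]]

# Optional routed scaling, applied to the MoE output after the kernels run.
output = torch.randn(2, 16)
routed_scaling_factor = 2.5
if routed_scaling_factor is not None:
    output *= routed_scaling_factor
```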