Unverified Commit d62cfe54 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[MoE Refactoring][Bugfix]Wrap WNA16 Triton kernel into mk and change...


[MoE Refactoring][Bugfix]Wrap WNA16 Triton kernel into mk and change compressed tensor kernel selection (#31752)
Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent 6cdf015c
...@@ -85,6 +85,7 @@ if HAS_TRITON: ...@@ -85,6 +85,7 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk, GroupedTopk,
TritonExperts, TritonExperts,
TritonWNA16Experts,
fused_experts, fused_experts,
fused_topk, fused_topk,
get_config_file_name, get_config_file_name,
...@@ -103,6 +104,7 @@ if HAS_TRITON: ...@@ -103,6 +104,7 @@ if HAS_TRITON:
"CutlassBatchedExpertsFp8", "CutlassBatchedExpertsFp8",
"CutlassExpertsW4A8Fp8", "CutlassExpertsW4A8Fp8",
"TritonExperts", "TritonExperts",
"TritonWNA16Experts",
"BatchedTritonExperts", "BatchedTritonExperts",
"DeepGemmExperts", "DeepGemmExperts",
"BatchedDeepGemmExperts", "BatchedDeepGemmExperts",
......
...@@ -624,11 +624,11 @@ def invoke_fused_moe_wna16_triton_kernel( ...@@ -624,11 +624,11 @@ def invoke_fused_moe_wna16_triton_kernel(
compute_type: tl.dtype, compute_type: tl.dtype,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: list[int], block_shape: list[int] | None,
): ):
assert B_scale is not None and B_scale.ndim == 3 assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
assert block_shape is None or block_shape[0] == 0 assert block_shape is not None and block_shape[0] == 0
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
...@@ -2447,6 +2447,148 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2447,6 +2447,148 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
ops.moe_sum(input, output) ops.moe_sum(input, output)
class TritonWNA16Experts(TritonExperts):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
)
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
if global_num_experts == -1:
global_num_experts = E
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
)
invoke_fused_moe_wna16_triton_kernel(
hidden_states,
w1,
intermediate_cache1,
self.w1_scale,
self.quant_config.w1_zp,
None, # topk_weights
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False, # mul_routed_weights
top_k_num,
config,
compute_type=compute_type,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
block_shape=self.block_shape,
)
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
a2q_scale: torch.Tensor | None = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
self.quant_dtype,
self.per_act_token_quant,
self.block_shape,
)
invoke_fused_moe_wna16_triton_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
self.w2_scale,
self.quant_config.w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
block_shape=self.block_shape,
)
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe( def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
......
...@@ -1693,11 +1693,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1693,11 +1693,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonWNA16Experts
layer.w13_weight = layer.w13_weight_packed layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed layer.w2_weight = layer.w2_weight_packed
return TritonExperts(quant_config=self.moe_quant_config) return TritonWNA16Experts(quant_config=self.moe_quant_config)
else: else:
raise NotImplementedError( raise NotImplementedError(
"TritonExperts requires Triton. " "TritonExperts requires Triton. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment