"vscode:/vscode.git/clone" did not exist on "9bca40296e3f00fb26597a0f4cfe2fdfd2ad2fd2"
Unverified Commit d332aa3b authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix: resolve fp8 moe issue (#2387)

parent c36736c8
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8 import Fp8Config
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
return QUANTIZATION_METHODS[quantization] return QUANTIZATION_METHODS[quantization]
def fp8_moe_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
"""Enhanced apply method for FP8 MoE."""
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
def fp8_get_quant_method(self, layer, prefix): def fp8_get_quant_method(self, layer, prefix):
"""Enhanced get_quant_method for FP8 config.""" """Enhanced get_quant_method for FP8 config."""
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
...@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix): ...@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.linear import UnquantizedLinearMethod from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
...@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix): ...@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
def apply_monkey_patches(): def apply_monkey_patches():
"""Apply all monkey patches in one place.""" """Apply all monkey patches in one place."""
setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
......
...@@ -24,11 +24,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -24,11 +24,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
...@@ -100,6 +95,8 @@ class Fp8Config(QuantizationConfig): ...@@ -100,6 +95,8 @@ class Fp8Config(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import from vllm.attention.layer import Attention # Avoid circular import
from sglang.srt.layers.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
...@@ -306,7 +303,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -306,7 +303,7 @@ class Fp8LinearMethod(LinearMethodBase):
) )
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod:
"""MoE method for FP8. """MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale. dynamic/static activation scale.
...@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __init__(self, quant_config: Fp8Config): def __new__(cls, *args, **kwargs):
from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
...@@ -331,6 +346,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -331,6 +346,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
...@@ -521,8 +537,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -521,8 +537,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.fused_moe_triton import FusedMoE
from vllm.model_executor.layers.fused_moe import fused_experts from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment