Unverified Commit 2339d59f authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[BugFix] Fix quantization for all other methods (#11547)

parent 1b875a0e
...@@ -41,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -41,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply(self, layer: torch.nn.Module, x: torch.Tensor, def apply(
router_logits: torch.Tensor, top_k: int, renormalize: bool, self,
use_grouped_topk: bool) -> torch.Tensor: layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -79,7 +90,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -79,7 +90,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
......
...@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
...@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
...@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
...@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
...@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
......
...@@ -601,14 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -601,14 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
......
...@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool = True, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
...@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=None) custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment