Unverified Commit cd4b39a9 authored by Bowen Bao's avatar Bowen Bao Committed by GitHub
Browse files

[quantization] Properly ignore quantization for layers excluded in quant_config (#11205)

parent 420c99ac
...@@ -207,15 +207,11 @@ class FusedMoE(torch.nn.Module): ...@@ -207,15 +207,11 @@ class FusedMoE(torch.nn.Module):
gemm1_clamp_limit=gemm1_clamp_limit, gemm1_clamp_limit=gemm1_clamp_limit,
) )
if quant_config is None: self.quant_method: Optional[FusedMoEMethodBase] = None
self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod( if quant_config is not None:
self.use_triton_kernels self.quant_method = quant_config.get_quant_method(self, prefix)
) if self.quant_method is None:
else: self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)
self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
self, prefix
)
assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
......
...@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig): ...@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer( if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
): ):
return UnquantizedLinearMethod() if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return None
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment