Unverified Commit 399c7986 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Remove ScaledActivation for AWQ (#10057)


Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
parent 406d4cc4
......@@ -9,7 +9,6 @@ import torch.nn.functional as F
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict
......@@ -277,28 +276,14 @@ _ACTIVATION_REGISTRY = LazyDict({
})
def get_act_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
return _ACTIVATION_REGISTRY[act_fn_name]
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
......@@ -307,25 +292,11 @@ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
})
def get_act_and_mul_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
......@@ -213,9 +213,6 @@ class AQLMConfig(QuantizationConfig):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class AQLMLinearMethod(LinearMethodBase):
"""Linear method for AQLM.
......
......@@ -77,9 +77,6 @@ class AWQConfig(QuantizationConfig):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
......
......@@ -127,9 +127,6 @@ class AWQMarlinConfig(QuantizationConfig):
return AWQMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
......
......@@ -133,11 +133,3 @@ class QuantizationConfig(ABC):
method.
"""
raise NotImplementedError
@abstractmethod
def get_scaled_act_names(self) -> List[str]:
"""Returns the activation function names that should be post-scaled.
For now, this is only used by AWQ.
"""
raise NotImplementedError
......@@ -114,9 +114,6 @@ class BitsAndBytesConfig(QuantizationConfig):
return BitsAndBytesLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
# Split the prefix into its dot-separated components
......
......@@ -45,9 +45,6 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
......
......@@ -50,9 +50,6 @@ class DeepSpeedFPConfig(QuantizationConfig):
def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
return DeepSpeedFPLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
......
......@@ -45,9 +45,6 @@ class ExpertsInt8Config(QuantizationConfig):
return ExpertsInt8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
......
......@@ -64,9 +64,6 @@ class FBGEMMFp8Config(QuantizationConfig):
return FBGEMMFp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class FBGEMMFp8LinearMethod(LinearMethodBase):
......
......@@ -92,9 +92,6 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
......
......@@ -48,9 +48,6 @@ class GGUFConfig(QuantizationConfig):
return GGUFEmbeddingMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
......
......@@ -80,9 +80,6 @@ class GPTQConfig(QuantizationConfig):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ExllamaState(Enum):
......
......@@ -125,9 +125,6 @@ class GPTQMarlinConfig(QuantizationConfig):
return GPTQMarlinMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
......
......@@ -127,9 +127,6 @@ class GPTQMarlin24Config(QuantizationConfig):
return GPTQMarlin24LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class GPTQMarlin24LinearMethod(LinearMethodBase):
"""Linear method for Marlin24.
......
......@@ -93,12 +93,6 @@ class IPEXConfig(QuantizationConfig):
return self.quant_method(self)
return None
def get_scaled_act_names(self) -> List[str]:
if self.method == "awq":
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
else:
return []
class IPEXAWQLinearMethod(AWQLinearMethod):
"""AWQ linear method using IPEX for the CPU backend.
......
......@@ -110,9 +110,6 @@ class MarlinConfig(QuantizationConfig):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
......
......@@ -68,9 +68,6 @@ class ModelOptFp8Config(QuantizationConfig):
return ModelOptFp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
......
......@@ -57,9 +57,6 @@ class NeuronQuantConfig(QuantizationConfig):
"Neuron Quantization is only supported through"
" transformers_neuronx.")
def get_scaled_act_names(self) -> List[str]:
return []
def get_quantization_config(self):
from transformers_neuronx.config import QuantizationConfig
return QuantizationConfig(quant_dtype=self.quant_dtype,
......
......@@ -112,9 +112,6 @@ class QQQConfig(QuantizationConfig):
return QQQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class QQQLinearMethod(LinearMethodBase):
"""Linear method for QQQ.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment