"Python/Classifier_OffloadFalse.py" did not exist on "a30cc948a0d7ac954fc53aad73e2cf307bd1e072"
Unverified commit 8d17774f, authored by Woosuk Kwon, committed by GitHub

Add AWQ support for all models (#1714)

parent e946260c
"""Custom activation functions.""" """Custom activation functions."""
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import activation_ops from vllm import activation_ops
from vllm.model_executor.layers.quantization import QuantizationConfig
class SiluAndMul(nn.Module): class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
         return out
 
 
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
 }
 
 
-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
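
Editor's note: as a quick illustration of what ScaledActivation computes, here is a minimal stand-alone sketch. It is not the code from this diff: ToyScaledActivation is a hypothetical stand-in, it uses torch.ones instead of torch.empty so it runs deterministically on CPU, and it imports nothing from vLLM.

import torch
import torch.nn as nn

class ToyScaledActivation(nn.Module):
    """Applies the wrapped activation, then divides by learned per-channel scales."""

    def __init__(self, act_module: nn.Module, hidden_size: int):
        super().__init__()
        self.act = act_module
        # Initialized to ones here so the sketch is deterministic; an AWQ
        # checkpoint would load real per-channel scales into this parameter.
        self.scales = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

act = ToyScaledActivation(nn.GELU(), hidden_size=8)
x = torch.randn(2, 8)
out = act(x)
print(out.shape)  # torch.Size([2, 8]); scales broadcast over the last dimension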
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
......
@@ -54,3 +54,11 @@ class QuantizationConfig(ABC):
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
@@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_linear_method(self) -> "SqueezeLLMLinearMethod":
         return SqueezeLLMLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
 
 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.
......
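
Editor's note: a sketch of the get_scaled_act_names contract introduced above. This is an assumed simplification, not the real code: plain toy classes stand in for the QuantizationConfig subclasses, and needs_scaled_act is a trimmed-down version of the check inside get_act_fn.

class ToyAWQConfig:
    def get_scaled_act_names(self):
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

class ToySqueezeLLMConfig:
    def get_scaled_act_names(self):
        return []  # SqueezeLLM does not post-scale any activation

def needs_scaled_act(act_fn_name, quant_config):
    """Mirrors the new check in get_act_fn: scale only when the backend asks for it."""
    return (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names())

print(needs_scaled_act("gelu", ToyAWQConfig()))         # True  -> ScaledActivation
print(needs_scaled_act("gelu", ToySqueezeLLMConfig()))  # False -> plain nn.GELU()
print(needs_scaled_act("gelu", None))                   # False -> unquantized model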
@@ -145,7 +145,8 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
@@ -154,7 +155,7 @@ class BloomMLP(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
-        x = self.act(x)
+        x = self.gelu_impl(x)
         x, _ = self.dense_4h_to_h(x)
         return x
......
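
Editor's note: the same wiring pattern is repeated in every model MLP below, so here is one hedged sketch of what it does. In vLLM the object would be the linear method produced by the quantization config; FakeLinearMethod is only a hypothetical stand-in.

class FakeLinearMethod:
    def __init__(self, quant_config):
        self.quant_config = quant_config

# Unquantized models pass linear_method=None, so getattr falls back to None
# and get_act_fn returns the plain activation function.
linear_method = None
print(getattr(linear_method, "quant_config", None))  # None

# Quantized models carry a quant_config, which get_act_fn uses to decide
# whether to wrap the activation in ScaledActivation.
linear_method = FakeLinearMethod(quant_config="awq-config")
print(getattr(linear_method, "quant_config", None))  # 'awq-config'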
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
@@ -131,6 +132,7 @@ class FalconAttention(nn.Module):
                 self.hidden_size,
                 bias=config.bias,
                 skip_bias_add=True,
+                linear_method=linear_method,
                 reduce_results=self.reduce_row_parallel_results)
 
         self.use_rotary = config.rotary
@@ -206,7 +208,8 @@ class FalconMLP(nn.Module):
                                                bias=config.bias,
                                                skip_bias_add=True,
                                                linear_method=linear_method)
-        self.act = nn.GELU()
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
......
@@ -118,7 +118,9 @@ class GPT2MLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
......
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
......
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
......
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
......
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
......
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        self.activation_fn = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
         if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
+            inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
         for i in range(len(self.layers)):
@@ -266,7 +268,7 @@ class OPTDecoder(nn.Module):
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
         if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
+            hidden_states, _ = self.project_out(hidden_states)
         return hidden_states
......
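
Editor's note: the project_in/project_out change in OPTDecoder reflects that these layers now return an (output, bias) pair rather than a bare tensor, so the call sites unpack two values. A hedged sketch of that calling convention follows; ToyParallelLinear is a hypothetical stand-in, not the vLLM class.

import torch
import torch.nn as nn

class ToyParallelLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor):
        output = self.linear(x)
        bias = None  # returned separately so callers can skip or defer the bias add
        return output, bias

project_in = ToyParallelLinear(16, 32)
x = torch.randn(2, 16)
hidden, _ = project_in(x)   # matches the new `inputs_embeds, _ = self.project_in(...)`
print(hidden.shape)         # torch.Size([2, 32])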
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
......