Unverified Commit 8d17774f authored by Woosuk Kwon, committed by GitHub

Add AWQ support for all models (#1714)

parent e946260c
"""Custom activation functions."""
from typing import Optional

import torch
import torch.nn as nn

from vllm import activation_ops
from vllm.model_executor.layers.quantization import QuantizationConfig


class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
        return out


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        hidden_size: int,
        params_dtype: torch.dtype,
    ):
        super().__init__()
        self.act = act_module
        self.scales = nn.Parameter(
            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))

    def forward(self, x: torch.Tensor):
        return self.act(x) / self.scales
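
For orientation, here is a minimal, self-contained sketch (not part of the commit) of what ScaledActivation computes: run the wrapped activation, then divide element-wise by a per-channel scales parameter that a quantization method such as AWQ loads from its checkpoint. The sizes and values below are made up, and it runs on CPU purely for illustration.

import torch
import torch.nn as nn

# Illustrative only: dummy per-channel scales; in the real module these live in
# an nn.Parameter allocated on "cuda" and are overwritten by checkpoint weights.
hidden_size = 8
act = nn.GELU()
scales = torch.full((hidden_size,), 2.0)

x = torch.randn(2, hidden_size)   # (num_tokens, hidden_size)
y = act(x) / scales               # same shape; each channel divided by its scale
assert y.shape == x.shape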
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
}
def get_act_fn(act_fn: str) -> nn.Module:
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if quant_config is not None:
        if act_fn_name in quant_config.get_scaled_act_names():
            if intermediate_size is None:
                raise ValueError(
                    "intermediate_size must be specified for scaled "
                    "activation functions.")
            return ScaledActivation(
                act_fn,
                intermediate_size,
                params_dtype=torch.get_default_dtype(),
            )
    return act_fn
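
A short usage sketch (not part of the diff) of the extended signature: without a quant_config, the plain registry entry comes back; with a quant_config whose get_scaled_act_names() includes the requested activation, the same call returns a ScaledActivation wrapper, which is why callers must now also pass the intermediate size. The sizes below are made up.

import torch
from vllm.model_executor.layers.activation import get_act_fn

# Unquantized path: no quant_config, so the registry entry is returned as-is.
act = get_act_fn("gelu_fast")
x = torch.randn(2, 16)
y = act(x)

# Quantized path (sketch): with an AWQ-style quant_config that lists
# "gelu_fast", the call below would instead return a ScaledActivation sized by
# intermediate_size.  `quant_config` and `intermediate_size` are placeholders.
# act = get_act_fn("gelu_fast", quant_config, intermediate_size)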
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.
......
@@ -54,3 +54,11 @@ class QuantizationConfig(ABC):
    def get_linear_method(self) -> LinearMethodBase:
        """Get the linear method to use for the quantized linear layer."""
        raise NotImplementedError

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError
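
To make the new contract concrete, here is a small sketch with hypothetical classes (not vLLM's) showing how the names returned here feed the check added to get_act_fn above:

from typing import List, Optional

class DummyAWQLikeConfig:
    """Hypothetical config used only for illustration."""

    def get_scaled_act_names(self) -> List[str]:
        # Only these activations get wrapped in ScaledActivation.
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

def needs_scaled_act(act_fn_name: str,
                     quant_config: Optional[DummyAWQLikeConfig]) -> bool:
    # Mirrors the condition in get_act_fn above.
    return (quant_config is not None
            and act_fn_name.lower() in quant_config.get_scaled_act_names())

assert needs_scaled_act("gelu_new", DummyAWQLikeConfig())
assert not needs_scaled_act("silu", DummyAWQLikeConfig())
assert not needs_scaled_act("gelu", None)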
@@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig):
    def get_linear_method(self) -> "SqueezeLLMLinearMethod":
        return SqueezeLLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class SqueezeLLMLinearMethod(LinearMethodBase):
    """Linear method for SqueezeLLM.
......
@@ -145,7 +145,8 @@ class BloomMLP(nn.Module):
            4 * hidden_size,
            linear_method=linear_method,
        )
        self.act = get_act_fn("gelu")
        quant_config = getattr(linear_method, "quant_config", None)
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
@@ -154,7 +155,7 @@
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.act(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x
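
The getattr lookup above is the pattern this commit repeats in every model file: when the model is not quantized, linear_method is typically None (or carries no quant_config attribute), so quant_config falls back to None and get_act_fn returns the plain activation. A self-contained sketch with stand-in classes (not vLLM's):

from typing import Optional

class FakeQuantConfig:
    """Stand-in for a quantization config."""

    def get_scaled_act_names(self):
        return ["gelu"]

class FakeLinearMethod:
    """Stand-in for a linear method that carries a quant config."""

    def __init__(self, quant_config: Optional[FakeQuantConfig]):
        self.quant_config = quant_config

# Quantized model: the linear method exposes its quant config.
linear_method = FakeLinearMethod(FakeQuantConfig())
assert getattr(linear_method, "quant_config", None) is not None

# Unquantized model: no linear method at all, so the lookup falls back to None.
linear_method = None
assert getattr(linear_method, "quant_config", None) is None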
......
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (PagedAttention,
                                                  PagedAttentionWithALiBi,
                                                  PagedAttentionWithRoPE)
@@ -131,6 +132,7 @@ class FalconAttention(nn.Module):
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            linear_method=linear_method,
            reduce_results=self.reduce_row_parallel_results)
        self.use_rotary = config.rotary
@@ -206,7 +208,8 @@ class FalconMLP(nn.Module):
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  linear_method=linear_method)
        self.act = nn.GELU()
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
......
@@ -118,7 +118,9 @@ class GPT2MLP(nn.Module):
            bias=True,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.activation_function)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
......
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
            bias=True,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.activation_function)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
......
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
            hidden_size,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.activation_function)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
......
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
            config.hidden_size,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.hidden_act)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
......
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        self.act = get_act_fn("gelu")
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
......
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = get_act_fn(config.activation_function)
        quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i in range(len(self.layers)):
@@ -266,7 +268,7 @@
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
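
The extra unpacking for project_in and project_out reflects that these projections now appear to go through vLLM's linear-layer abstractions (the constructor change sits outside the shown hunks), which return an (output, optional bias) tuple rather than a bare tensor. A hypothetical stand-in illustrating just the calling convention:

from typing import Optional, Tuple

import torch
import torch.nn as nn

class TupleLinear(nn.Module):
    """Hypothetical layer returning (output, optional bias) like vLLM's linear layers."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # No bias here, so the second element is None; callers unpack and ignore it.
        return torch.nn.functional.linear(x, self.weight), None

project_in = TupleLinear(512, 768)
inputs_embeds = torch.randn(4, 512)
inputs_embeds, _ = project_in(inputs_embeds)   # matches the updated call sites above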
......
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
            config.hidden_size,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.activation_function)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              n_inner)

    def forward(self, hidden_states):
        hidden_states, _ = self.fc1(hidden_states)
......