Unverified Commit a62aaf1d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

parent 603ad848
......@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod)
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
......@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
......
......@@ -389,10 +389,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor,
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
......@@ -416,7 +415,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
......@@ -523,10 +522,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
......@@ -765,10 +763,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
......@@ -862,9 +859,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
......@@ -897,7 +893,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
......
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
......@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
......@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC):
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
......@@ -50,7 +51,7 @@ class LinearMethodBase(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
......@@ -59,13 +60,6 @@ class LinearMethodBase(ABC):
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
......@@ -92,7 +86,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
......@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module):
"""Replicated linear layer.
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
......@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self, self.input_size,
if quant_config is None:
self.quant_method = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
......@@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self, x, bias)
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class ColumnParallelLinear(torch.nn.Module):
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
......@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
......@@ -184,28 +208,20 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
super().__init__()
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
if output_sizes is None:
output_sizes = [output_size]
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
......@@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(self, input_, bias)
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
......@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method,
skip_bias_add, params_dtype, quant_config,
self.output_sizes)
def weight_loader(self,
......@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
......@@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear):
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method, output_sizes)
params_dtype, quant_config, output_sizes)
def weight_loader(self,
param: Parameter,
......@@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module):
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
......@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -552,26 +568,18 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
......@@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self, input_parallel)
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
......
......@@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
......@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": FP8Config,
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
......
......@@ -9,10 +9,10 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype:
......@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
def get_linear_method(self) -> "AQLMLinearMethod":
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig):
......@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod":
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
......@@ -147,7 +150,7 @@ class AWQLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
......
......@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class QuantizationConfig(ABC):
......@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
"quantization config.")
@abstractmethod
def get_linear_method(self) -> LinearMethodBase:
"""Get the linear method to use for the quantized linear layer."""
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
"""Get the quantize method to use for the quantized layer."""
raise NotImplementedError
@abstractmethod
......
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class FP8Config(QuantizationConfig):
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@classmethod
......@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls()
def get_linear_method(self) -> "Fp8LinearMethod":
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: FP8Config):
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(
......@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scaling_factor", w_scale)
def process_weights_after_loading(self, layer: Module) -> None:
# Although the linear_method is propagated to all layers,
# Although the quant_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"):
return
qweight, weight_scale = per_tensor_quantize(layer.weight)
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scaling_factor.data.copy_(weight_scale)
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qinput, x_scale = per_tensor_quantize(x)
qinput, x_scale = ops.scaled_fp8_quant(x)
output, _ = torch._scaled_mm(
qinput,
layer.weight,
......@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias,
)
return output
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
......@@ -7,10 +7,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig):
......@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod":
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -194,7 +197,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.exllama_state = exllama_state
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
......
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig):
......@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod":
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip
......@@ -51,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return
def get_scaled_act_names(self) -> List[str]:
return []
class SqueezeLLMLinearMethod(LinearMethodBase):
class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM.
Args:
......@@ -112,7 +116,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
......
......@@ -3,8 +3,7 @@ import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
Type)
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import torch
from torch import nn
......@@ -13,6 +12,8 @@ from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator)
......@@ -24,9 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
......@@ -34,11 +32,10 @@ _VISION_MODEL_CLASSES = [
logger = init_logger(__name__)
def _get_linear_method(
def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
"""Get the (maybe quantized) linear method."""
linear_method = None
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
......@@ -55,9 +52,8 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
return linear_method
return quant_config
return None
def _get_model_initialization_kwargs(
......@@ -85,10 +81,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config)
quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config,
linear_method=linear_method,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
......@@ -229,9 +225,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()
......@@ -314,11 +312,11 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config,
self.load_config)
quant_config = _get_quantization_config(
model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method
extra_kwargs["quant_config"] = quant_config
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
......
......@@ -13,7 +13,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -251,7 +252,7 @@ class TensorizerAgent:
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_load_fail is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
......@@ -262,10 +263,10 @@ class TensorizerAgent:
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None:
self.linear_method = extra_kwargs["linear_method"]
if extra_kwargs.get("quant_config", None) is not None:
self.quant_config = extra_kwargs["quant_config"]
else:
self.linear_method = linear_method
self.quant_config = quant_config
self.model = self._init_model()
def _init_model(self):
......@@ -274,7 +275,7 @@ class TensorizerAgent:
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
linear_method=self.linear_method,
quant_config=self.quant_config,
**self.extra_kwargs)
def _resize_lora_embeddings(self):
......
......@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads,
self.total_num_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
......@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
......@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, linear_method)
BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method)
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method, lora_config)
super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method, lora_config)
super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
......@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", linear_method, lora_config)
super().__init__(config, "ROPE", quant_config, lora_config)
......@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
......@@ -129,21 +130,21 @@ class BloomMLP(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -158,17 +159,17 @@ class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, linear_method)
self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, linear_method)
self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
......@@ -214,7 +215,7 @@ class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -229,7 +230,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, linear_method)
BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = BloomModel(config, linear_method)
self.quant_config = quant_config
self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
......@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
self.activation_func = SiluAndMul()
......@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(self, hidden_states):
......@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
......@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, linear_method)
self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
......@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, linear_method)
self.mlp = GLMMLP(config, quant_config)
def forward(
self,
......@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
......@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
......@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, linear_method)
self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
......@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
......
......@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
......@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
......@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, linear_method=linear_method)
self.self_attn = CohereAttention(config, quant_config=quant_config)
self.mlp = CohereMLP(config, linear_method=linear_method)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
......@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, linear_method=linear_method)
CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(param_shape=(config.hidden_size),
......@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config, linear_method)
self.model = CohereModel(config, quant_config)
self.sampler = Sampler()
@torch.no_grad()
......
......@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
linear_method=None,
quant_config=None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
......@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
......@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, linear_method)
self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
......@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
self.ffn = DbrxExperts(config, linear_method)
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, quant_config)
def forward(
self,
......@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.wte = VocabParallelEmbedding(
......@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
......@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, linear_method)
self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
......
......@@ -29,7 +29,8 @@ import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
......@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config,
linear_method=linear_method,
quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment