Unverified Commit a62aaf1d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

parent 603ad848
...@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None: ...@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1 fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod) assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn assert fc1.weight.dtype == torch.float8_e4m3fn
...@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): ...@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock() mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config, result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method) quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config, mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method) quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once() mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value assert result == mock_agent_instance.deserialize.return_value
......
...@@ -389,10 +389,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -389,10 +389,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices self.indices = base_indices
self.indices_len = indices_len self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -416,7 +415,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -416,7 +415,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if not self.base_layer.skip_bias_add else None) if not self.base_layer.skip_bias_add else None)
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply_weights(input_, bias) output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output: if self.base_layer.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -523,10 +522,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -523,10 +522,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True) lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -765,10 +763,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -765,10 +763,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True) lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -862,9 +859,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -862,9 +859,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices self.indices = base_indices
self.indices_len = indices_len self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x)
self.base_layer, x)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -897,7 +893,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -897,7 +893,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply_weights(input_parallel) output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1: if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
......
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import List, Optional from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset): ...@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
@abstractmethod @abstractmethod
...@@ -50,22 +51,15 @@ class LinearMethodBase(ABC): ...@@ -50,22 +51,15 @@ class LinearMethodBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor. """Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer.""" Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization. """Linear method without quantization.
...@@ -92,10 +86,10 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -92,10 +86,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs) set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight weight = layer.weight
if self.separate_bias_add: if self.separate_bias_add:
if bias is not None: if bias is not None:
...@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module): class LinearBase(torch.nn.Module):
"""Replicated linear layer. """Base linear layer.
Args: Args:
input_size: input dimension of the linear layer. input_size: input dimension of the linear layer.
...@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias. bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it. skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
output_size: int, output_size: int,
bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module):
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
if linear_method is None: if quant_config is None:
linear_method = UnquantizedLinearMethod() self.quant_method = UnquantizedLinearMethod()
self.linear_method = linear_method else:
self.linear_method.create_weights(self, self.input_size, self.quant_method = quant_config.get_quant_method(self)
[self.output_size], self.input_size,
self.output_size, self.params_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype)) torch.empty(self.output_size, dtype=self.params_dtype))
...@@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self, x, bias) output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along The linear layer is defined as Y = XA + b. A is parallelized along
...@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3. the list would be size 3.
""" """
...@@ -184,34 +208,26 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -184,34 +208,26 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None, output_sizes: Optional[List[int]] = None,
): ):
super().__init__() super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size) self.output_size_per_partition = divide(output_size, tp_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
self.linear_method = linear_method self.quant_method.create_weights(self,
self.linear_method.create_weights(self, self.input_size,
self.input_size, [x // tp_size for x in output_sizes],
[x // tp_size for x in output_sizes], self.input_size,
self.input_size, self.output_size,
self.output_size, self.params_dtype,
self.params_dtype, weight_loader=self.weight_loader)
weight_loader=self.weight_loader)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
...@@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output, super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method, skip_bias_add, params_dtype, quant_config,
self.output_sizes) self.output_sizes)
def weight_loader(self, def weight_loader(self,
...@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
...@@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
...@@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear):
] ]
super().__init__(input_size, output_size, bias, False, skip_bias_add, super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method, output_sizes) params_dtype, quant_config, output_sizes)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module): class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism. """Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along The linear layer is defined as Y = XA + b. A is parallelized along
...@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. bias can be fused with other element-wise operations.
We skip adding bias but instead return it. We skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
...@@ -552,32 +568,24 @@ class RowParallelLinear(torch.nn.Module): ...@@ -552,32 +568,24 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True, reduce_results: bool = True,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__(input_size, output_size, skip_bias_add, params_dtype,
# Keep input parameters quant_config)
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add self.quant_method.create_weights(self,
if linear_method is None: self.input_size_per_partition,
linear_method = UnquantizedLinearMethod() [self.output_size],
self.linear_method = linear_method self.input_size,
self.linear_method.create_weights(self, self.output_size,
self.input_size_per_partition, self.params_dtype,
[self.output_size], weight_loader=self.weight_loader)
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
...@@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights( output_parallel = self.quant_method.apply(self, input_parallel)
self, input_parallel)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
......
...@@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig ...@@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = { QUANTIZATION_METHODS = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"fp8": FP8Config, "fp8": Fp8Config,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig, "marlin": MarlinConfig,
......
...@@ -9,10 +9,10 @@ import torch.nn.functional as F ...@@ -9,10 +9,10 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype: def get_int_dtype(nbits: int) -> torch.dtype:
...@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig): ...@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return cls(in_group_size, nbits_per_codebook, num_code_books, return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size) out_group_size)
def get_linear_method(self) -> "AQLMLinearMethod": def get_quant_method(
return AQLMLinearMethod(self) self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase): ...@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs) set_weight_attrs(scales, extra_weight_attrs)
def apply_weights( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig): ...@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point) return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod": def get_quant_method(
return AWQLinearMethod(self) self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
...@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs) set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
scales = layer.scales scales = layer.scales
qzeros = layer.qzeros qzeros = layer.qzeros
......
...@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod ...@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
...@@ -51,8 +76,8 @@ class QuantizationConfig(ABC): ...@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
"quantization config.") "quantization config.")
@abstractmethod @abstractmethod
def get_linear_method(self) -> LinearMethodBase: def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
"""Get the linear method to use for the quantized linear layer.""" """Get the quantize method to use for the quantized layer."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm import _custom_ops as ops
set_weight_attrs) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class FP8Config(QuantizationConfig): class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
@classmethod @classmethod
...@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig): ...@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config": def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls() return cls()
def get_linear_method(self) -> "Fp8LinearMethod": def get_quant_method(
return Fp8LinearMethod(self) self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __init__(self, quant_config: FP8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
...@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scaling_factor", w_scale) layer.register_parameter("weight_scaling_factor", w_scale)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Although the linear_method is propagated to all layers, # Although the quant_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check # only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine # whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization. # whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"): if not hasattr(layer, "weight_scaling_factor"):
return return
qweight, weight_scale = per_tensor_quantize(layer.weight) qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
# torch._scaled_mm requires column-major in the second # torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight. # input (weight), so we transpose the quantized weight.
layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scaling_factor.data.copy_(weight_scale) layer.weight_scaling_factor.data.copy_(weight_scale)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qinput, x_scale = per_tensor_quantize(x) qinput, x_scale = ops.scaled_fp8_quant(x)
output, _ = torch._scaled_mm( output, _ = torch._scaled_mm(
qinput, qinput,
layer.weight, layer.weight,
...@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return output return output
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
...@@ -7,10 +7,10 @@ import torch ...@@ -7,10 +7,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig): class GPTQConfig(QuantizationConfig):
...@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig): ...@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act = cls.get_from_keys(config, ["desc_act"]) desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act) return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod": def get_quant_method(
return GPTQLinearMethod(self) self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
layer.exllama_state = exllama_state layer.exllama_state = exllama_state
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], ) out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
......
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig): class MarlinConfig(QuantizationConfig):
...@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig): ...@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size) return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod": def get_quant_method(
return MarlinLinearMethod(self) self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer.register_parameter("workspace", workspace) layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs) set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip from vllm.utils import is_hip
...@@ -51,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -51,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits) return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod": def get_quant_method(
return SqueezeLLMLinearMethod(self) self,
layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
class SqueezeLLMLinearMethod(LinearMethodBase): class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM. """Linear method for SqueezeLLM.
Args: Args:
...@@ -112,10 +116,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase): ...@@ -112,10 +116,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer.register_parameter("lookup_table", lookup_table) layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs) set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
lookup_table = layer.lookup_table lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], ) out_shape = x.shape[:-1] + (qweight.shape[-1], )
......
...@@ -3,8 +3,7 @@ import copy ...@@ -3,8 +3,7 @@ import copy
import glob import glob
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, from typing import Any, Dict, Generator, List, Optional, Tuple, Type
Type)
import torch import torch
from torch import nn from torch import nn
...@@ -13,6 +12,8 @@ from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, ...@@ -13,6 +12,8 @@ from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig) SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator) tensorizer_weights_iterator)
...@@ -24,9 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -24,9 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator) pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [ _VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
] ]
...@@ -34,11 +32,10 @@ _VISION_MODEL_CLASSES = [ ...@@ -34,11 +32,10 @@ _VISION_MODEL_CLASSES = [
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_linear_method( def _get_quantization_config(
model_config: ModelConfig, model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]: load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the (maybe quantized) linear method.""" """Get the quantization config."""
linear_method = None
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
...@@ -55,9 +52,8 @@ def _get_linear_method( ...@@ -55,9 +52,8 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization " f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: " f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}") f"{supported_dtypes}")
return quant_config
linear_method = quant_config.get_linear_method() return None
return linear_method
def _get_model_initialization_kwargs( def _get_model_initialization_kwargs(
...@@ -85,10 +81,10 @@ def _initialize_model( ...@@ -85,10 +81,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config) quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config, return model_class(config=model_config.hf_config,
linear_method=linear_method, quant_config=quant_config,
**_get_model_initialization_kwargs( **_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)) model_class, lora_config, vision_language_config))
...@@ -229,9 +225,11 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -229,9 +225,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load", "fall_back_to_pt_during_load",
True)), ) True)), )
for _, module in model.named_modules(): for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None) quant_method = getattr(module, "quant_method", None)
if linear_method is not None: if quant_method is not None:
linear_method.process_weights_after_loading(module) quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"): if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading() module.process_weights_after_loading()
return model.eval() return model.eval()
...@@ -314,11 +312,11 @@ class TensorizerLoader(BaseModelLoader): ...@@ -314,11 +312,11 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, quant_config = _get_quantization_config(
self.load_config) model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs( extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config) model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method extra_kwargs["quant_config"] = quant_config
tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class tensorizer_config.model_class = model_class
......
...@@ -13,7 +13,8 @@ from transformers import PretrainedConfig ...@@ -13,7 +13,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -251,7 +252,7 @@ class TensorizerAgent: ...@@ -251,7 +252,7 @@ class TensorizerAgent:
""" """
def __init__(self, tensorizer_config: TensorizerConfig, def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs): quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_load_fail is not None: if tensorizer_load_fail is not None:
raise ImportError( raise ImportError(
"Tensorizer is not installed. Please install tensorizer " "Tensorizer is not installed. Please install tensorizer "
...@@ -262,10 +263,10 @@ class TensorizerAgent: ...@@ -262,10 +263,10 @@ class TensorizerAgent:
self.tensorizer_args = ( self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args()) self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None: if extra_kwargs.get("quant_config", None) is not None:
self.linear_method = extra_kwargs["linear_method"] self.quant_config = extra_kwargs["quant_config"]
else: else:
self.linear_method = linear_method self.quant_config = quant_config
self.model = self._init_model() self.model = self._init_model()
def _init_model(self): def _init_model(self):
...@@ -274,7 +275,7 @@ class TensorizerAgent: ...@@ -274,7 +275,7 @@ class TensorizerAgent:
with no_init_or_tensor(): with no_init_or_tensor():
return self.tensorizer_config.model_class( return self.tensorizer_config.model_class(
config=model_args, config=model_args,
linear_method=self.linear_method, quant_config=self.quant_config,
**self.extra_kwargs) **self.extra_kwargs)
def _resize_lora_embeddings(self): def _resize_lora_embeddings(self):
......
...@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module): ...@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module): ...@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str, position_embedding: str,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module): ...@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_heads, self.total_num_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI": if self.postion_embedding == "ALIBI":
...@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
...@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module): ...@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module): ...@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, linear_method) BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
config, config,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, linear_method) self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method, lora_config) super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method, lora_config) super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, "ROPE", linear_method, lora_config) super().__init__(config, "ROPE", quant_config, lora_config)
...@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -70,7 +71,7 @@ class BloomAttention(nn.Module): ...@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -87,13 +88,13 @@ class BloomAttention(nn.Module): ...@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
...@@ -129,21 +130,21 @@ class BloomMLP(nn.Module): ...@@ -129,21 +130,21 @@ class BloomMLP(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
hidden_size, hidden_size,
4 * hidden_size, 4 * hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None) quant_config = getattr(quant_config, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -158,17 +159,17 @@ class BloomBlock(nn.Module): ...@@ -158,17 +159,17 @@ class BloomBlock(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, linear_method) self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, linear_method) self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm) config.apply_residual_connection_post_layernorm)
...@@ -214,7 +215,7 @@ class BloomModel(nn.Module): ...@@ -214,7 +215,7 @@ class BloomModel(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -229,7 +230,7 @@ class BloomModel(nn.Module): ...@@ -229,7 +230,7 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
BloomBlock(config, linear_method) BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module): ...@@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = BloomModel(config, linear_method) self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig ...@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -33,7 +34,7 @@ class GLMAttention(nn.Module): ...@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -65,13 +66,13 @@ class GLMAttention(nn.Module): ...@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias, bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...@@ -123,7 +124,7 @@ class GLMMLP(nn.Module): ...@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -134,7 +135,7 @@ class GLMMLP(nn.Module): ...@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config.hidden_size, config.hidden_size,
[config.ffn_hidden_size] * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
...@@ -144,7 +145,7 @@ class GLMMLP(nn.Module): ...@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size, config.ffn_hidden_size,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -166,7 +167,7 @@ class GLMBlock(nn.Module): ...@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
...@@ -180,7 +181,7 @@ class GLMBlock(nn.Module): ...@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, linear_method) self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -188,7 +189,7 @@ class GLMBlock(nn.Module): ...@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
# MLP # MLP
self.mlp = GLMMLP(config, linear_method) self.mlp = GLMMLP(config, quant_config)
def forward( def forward(
self, self,
...@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module): ...@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
...@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module): ...@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[GLMBlock(config, linear_method) for i in range(self.num_layers)]) [GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
...@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module): ...@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module): ...@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, linear_method) self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size) config.hidden_size)
...@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = ChatGLMModel(config, linear_method) self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -91,7 +92,7 @@ class CohereMLP(nn.Module): ...@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -101,13 +102,13 @@ class CohereMLP(nn.Module): ...@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self.hidden_size, self.hidden_size,
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
self.intermediate_size, self.intermediate_size,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -123,7 +124,7 @@ class CohereAttention(nn.Module): ...@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -158,13 +159,13 @@ class CohereAttention(nn.Module): ...@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module): ...@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, linear_method=linear_method) self.self_attn = CohereAttention(config, quant_config=quant_config)
self.mlp = CohereMLP(config, linear_method=linear_method) self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -257,7 +258,7 @@ class CohereModel(nn.Module): ...@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -265,7 +266,7 @@ class CohereModel(nn.Module): ...@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
CohereDecoderLayer(config, linear_method=linear_method) CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
...@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module): ...@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale) scale=config.logit_scale)
self.model = CohereModel(config, linear_method) self.model = CohereModel(config, quant_config)
self.sampler = Sampler() self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
......
...@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (QKVParallelLinear,
QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module): ...@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=params_dtype, params_dtype=params_dtype,
linear_method=None, quant_config=None,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module): ...@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
): ):
super().__init__() super().__init__()
...@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module): ...@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module): ...@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
self.d_model, self.d_model,
self.d_model, self.d_model,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.attn = DbrxAttention(config, linear_method) self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model) self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model)
...@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module): ...@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, linear_method) self.ffn = DbrxExperts(config, quant_config)
def forward( def forward(
self, self,
...@@ -307,7 +308,7 @@ class DbrxModel(nn.Module): ...@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
...@@ -315,7 +316,7 @@ class DbrxModel(nn.Module): ...@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) [DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules(): for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias, if hasattr(module, "bias") and isinstance(module.bias,
...@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module): ...@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, linear_method) self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
......
...@@ -29,7 +29,8 @@ import torch ...@@ -29,7 +29,8 @@ import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
...@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, super().__init__(config=config,
linear_method=linear_method, quant_config=quant_config,
lora_config=lora_config) lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment