Unverified Commit 7076fa1c authored by Zhuohan Li, committed by GitHub

TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, with [GPTQ coming soon](https://github.com/vllm-project/vllm/pull/1580).
- The model-loading code is much simpler.
- Tensor parallelism is now supported for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size (the KV heads are replicated so that each GPU holds at least one).
parent 660a7fcf
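A minimal usage sketch of the features above, assuming the standard vllm.LLM entrypoint and its quantization / tensor_parallel_size arguments (the model name is the AWQ checkpoint referenced later in this diff; exact arguments may differ by vLLM version):

    from vllm import LLM, SamplingParams

    # Any supported architecture can now be served quantized; KV heads are
    # replicated automatically if tensor_parallel_size exceeds their count.
    llm = LLM(model="casperhansen/vicuna-7b-v1.5-awq",
              quantization="awq",
              tensor_parallel_size=2)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, max_tokens=32))
    print(outputs[0].outputs[0].text)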
...@@ -140,8 +140,8 @@ class ModelConfig:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
...@@ -155,23 +155,34 @@ class ModelConfig:
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
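A small worked example of the replication rule in get_num_kv_heads above (illustrative numbers only, mirroring the max(1, ...) logic shown in the diff):

    def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
        # Divide when possible; otherwise replicate so that every GPU
        # holds at least one KV head.
        return max(1, total_num_kv_heads // tp_size)

    assert num_kv_heads_per_gpu(32, 4) == 8   # plain division (e.g. MHA)
    assert num_kv_heads_per_gpu(8, 4) == 2    # GQA, heads split across GPUs
    assert num_kv_heads_per_gpu(8, 16) == 1   # heads replicated, not sharded
    assert num_kv_heads_per_gpu(1, 8) == 1    # MQA: the single head is replicated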
......
...@@ -142,10 +142,10 @@ class RequestTracker:
            self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
......
from typing import Type
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
_QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
]
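A quick sketch of how the registry above is used, assuming this file is the quantization package's __init__.py (consistent with the import paths elsewhere in this diff):

    from vllm.model_executor.layers.quantization import get_quantization_config

    awq_cls = get_quantization_config("awq")                # -> AWQConfig
    squeezellm_cls = get_quantization_config("squeezellm")  # -> SqueezeLLMConfig

    try:
        get_quantization_config("gptq")  # not registered yet in this PR
    except ValueError as e:
        print(e)  # "Invalid quantization method: gptq"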
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
def get_name(self) -> str:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@staticmethod
def get_config_filenames() -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod":
return AWQLinearMethod(self)
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
qzeros = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
return {
"qweight": qweight,
"qzeros": qzeros,
"scales": scales,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
qzeros = weights["qzeros"]
scales = weights["scales"]
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
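To make the tensor shapes above concrete, a small sketch with illustrative numbers (not taken from any specific model) of what create_weights allocates for a 4-bit, group-size-128 AWQ layer:

    input_size, output_size = 4096, 11008   # e.g. one MLP projection
    weight_bits, group_size = 4, 128
    pack_factor = 32 // weight_bits         # 8 INT4 values per INT32

    qweight_shape = (input_size, output_size // pack_factor)   # (4096, 1376)
    qzeros_shape = (input_size // group_size,
                    output_size // pack_factor)                 # (32, 1376)
    scales_shape = (input_size // group_size, output_size)      # (32, 11008)

    # apply_weights recovers the full output width from the packed qweight:
    out_features = qweight_shape[1] * pack_factor                # 11008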
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import LinearMethodBase


class QuantizationConfig(ABC):
    """Base class for quantization configs."""

    @abstractmethod
    def get_name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @abstractmethod
    def get_min_capability(self) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
...@@ -25,12 +29,14 @@ class QuantizationConfig:
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError
...@@ -44,42 +50,7 @@ class QuantizationConfig:
            raise ValueError(f"Cannot find any of {keys} in the model's "
                             "quantization config.")

    @abstractmethod
    def get_linear_method(self) -> LinearMethodBase:
        """Get the linear method to use for the quantized linear layer."""
        raise NotImplementedError
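To illustrate the abstract interface above, a minimal hypothetical subclass might look roughly like this; FooConfig and its scheme are invented purely for illustration and are not part of the diff:

    from typing import Any, Dict, List

    import torch

    from vllm.model_executor.layers.linear import LinearMethodBase
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)


    class FooConfig(QuantizationConfig):
        """Hypothetical method, shown only to demonstrate the ABC."""

        def get_name(self) -> str:
            return "foo"

        def get_supported_act_dtypes(self) -> List[torch.dtype]:
            return [torch.half]

        def get_min_capability(self) -> int:
            return 70  # Volta or newer

        @staticmethod
        def get_config_filenames() -> List[str]:
            return ["quant_config.json"]

        @classmethod
        def from_config(cls, config: Dict[str, Any]) -> "FooConfig":
            return cls()

        def get_linear_method(self) -> LinearMethodBase:
            # A real method would return its own LinearMethodBase subclass,
            # analogous to AWQLinearMethod above.
            raise NotImplementedError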
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
def get_name(self) -> str:
return "squeezellm"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
return SqueezeLLMLinearMethod(self)
class SqueezeLLMLinearMethod(LinearMethodBase):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size // self.quant_config.pack_factor,
output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
lookup_table = Parameter(
torch.empty(
output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
return {
"qweight": qweight,
"lookup_table": lookup_table,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
lookup_table = weights["lookup_table"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
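A rough sketch of the idea behind the lookup table above: each packed 4-bit code selects one of 16 per-output-channel centroid values. This is a simplified pure-Python dequantization, not the CUDA kernel, and the low-nibble-first packing order is an assumption for illustration only:

    import torch

    def dequantize_column(packed_col: torch.Tensor,
                          lut_row: torch.Tensor,
                          pack_factor: int = 8) -> torch.Tensor:
        # packed_col: int32 words, each holding `pack_factor` 4-bit codes
        #             (one column of qweight for a single output channel).
        # lut_row: the 16 centroid values for that output channel.
        codes = []
        for word in packed_col.tolist():
            for i in range(pack_factor):
                codes.append((word >> (4 * i)) & 0xF)  # assumed nibble order
        return lut_row[torch.tensor(codes)]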
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)
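A short sketch of how this factory dispatches (illustrative; awq_config stands for an AWQConfig instance constructed elsewhere, and the keyword arguments follow the parallel-layer API used in the rest of this diff):

    # Unquantized path: a plain ColumnParallelLinear is returned.
    proj = ParallelLinear.column(4096, 11008, bias=False, gather_output=False)

    # Quantized path: the registry picks AWQColumnParallelLinear instead.
    qproj = ParallelLinear.column(4096, 11008, bias=False, gather_output=False,
                                  quant_config=awq_config)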
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.group_size == 0
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.output_size % self.quant_config.pack_factor == 0
if self.input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size // self.quant_config.pack_factor,
self.output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size_per_partition,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class SqueezeLLMRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
return out.reshape(out_shape)
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
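To see why the masking above is sufficient, a small single-process sketch (a hypothetical 2-way split of a 6-token vocabulary; the all-reduce is simulated with a plain sum over ranks):

    import torch
    import torch.nn.functional as F

    vocab, dim, tp = 6, 4, 2
    full_weight = torch.randn(vocab, dim)
    input_ids = torch.tensor([0, 3, 5])

    partials = []
    for rank in range(tp):
        start, end = rank * vocab // tp, (rank + 1) * vocab // tp
        shard = full_weight[start:end]
        mask = (input_ids < start) | (input_ids >= end)
        masked = input_ids.clone() - start
        masked[mask] = 0
        out = F.embedding(masked, shard)
        out[mask, :] = 0.0          # out-of-range tokens contribute zeros
        partials.append(out)

    # The all-reduce (here a sum) recovers the full embedding lookup:
    assert torch.allclose(sum(partials), F.embedding(input_ids, full_weight))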
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
else:
self.register_parameter("bias", None)
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
...@@ -37,13 +37,6 @@ _MODEL_REGISTRY = {
    "YiForCausalLM": YiForCausalLM,
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
...@@ -67,12 +60,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:

def get_model(model_config: ModelConfig) -> nn.Module:
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the (maybe quantized) linear method.
    linear_method = None
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
...@@ -90,14 +80,12 @@ def get_model(model_config: ModelConfig) -> nn.Module:
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        linear_method = quant_config.get_linear_method()

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        model = model_class(model_config.hf_config, linear_method)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
......
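In short, the diff above replaces the per-model quantization allowlist with a single code path. A simplified sketch using the names from this diff:

    # Simplified: how get_model wires quantization into any model now.
    linear_method = None
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        linear_method = quant_config.get_linear_method()  # e.g. AWQLinearMethod

    # Every model class now accepts the linear method; no allowlist needed.
    model = model_class(model_config.hf_config, linear_method)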
...@@ -33,15 +33,17 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
...@@ -55,20 +57,17 @@ class AquilaMLP(nn.Module):
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
...@@ -111,6 +110,7 @@ class AquilaAttention(nn.Module):
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
...@@ -128,29 +128,29 @@ class AquilaAttention(nn.Module):
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)

    def forward(
        self,
...@@ -171,7 +171,11 @@ class AquilaAttention(nn.Module):

class AquilaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: AquilaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
...@@ -185,11 +189,13 @@ class AquilaDecoderLayer(nn.Module):
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            rope_scaling=rope_scaling,
            linear_method=linear_method,
        )
        self.mlp = AquilaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = AquilaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
...@@ -226,19 +232,22 @@ class AquilaDecoderLayer(nn.Module):

class AquilaModel(nn.Module):

    def __init__(
        self,
        config: AquilaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            AquilaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -271,17 +280,16 @@ class AquilaModel(nn.Module):

class AquilaForCausalLM(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = AquilaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
...@@ -298,79 +306,33 @@ class AquilaForCausalLM(nn.Module):
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
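The load_weights pattern above generalizes across models. A sketch of how one checkpoint key is routed through stacked_params_mapping (names come from this diff; the checkpoint key and loaded_weight are example inputs):

    # Checkpoint key from the HF model:
    name = "model.layers.0.self_attn.q_proj.weight"

    # The mapping entry ("qkv_proj", "q_proj", "q") matches, so the weight is
    # loaded into the fused parameter under shard id "q":
    param = params_dict[name.replace("q_proj", "qkv_proj")]
    # -> "model.layers.0.self_attn.qkv_proj.weight"
    param.weight_loader(param, loaded_weight, "q")

    # Keys that match no mapping entry (e.g. "model.norm.weight") fall through
    # to default_weight_loader, which copies the tensor directly.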
...@@ -30,18 +30,20 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                  PagedAttentionWithALiBi)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
...@@ -80,20 +82,17 @@ class BaiChuanMLP(nn.Module):
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
...@@ -116,6 +115,7 @@ class BaiChuanAttention(nn.Module):
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
...@@ -131,17 +131,19 @@ class BaiChuanAttention(nn.Module):
        self.max_position_embeddings = max_position_embeddings

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Create the alibi slopes and slice them.
        if self.postion_embedding == "ALIBI":
...@@ -188,7 +190,10 @@ class BaiChuanAttention(nn.Module):

class BaiChuanDecoderLayer(nn.Module):

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
...@@ -200,11 +205,13 @@ class BaiChuanDecoderLayer(nn.Module):
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
...@@ -241,7 +248,10 @@ class BaiChuanDecoderLayer(nn.Module):

class BaiChuanModel(nn.Module):

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
...@@ -252,7 +262,7 @@ class BaiChuanModel(nn.Module):
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -285,16 +295,15 @@ class BaiChuanModel(nn.Module):

class BaiChuanBaseForCausalLM(nn.Module):

    def __init__(self,
                 config,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = BaiChuanModel(config, position_embedding, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
...@@ -311,79 +320,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__(config, "ALIBI", linear_method)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__(config, "ROPE", linear_method)
...@@ -30,14 +30,17 @@ from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -70,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:

class BloomAttention(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
...@@ -81,17 +88,18 @@ class BloomAttention(nn.Module):
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Create the alibi slopes and slice them.
...@@ -125,19 +133,23 @@ class BloomAttention(nn.Module):

class BloomMLP(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            linear_method=linear_method,
        )
        self.act = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -149,16 +161,20 @@ class BloomMLP(nn.Module):

class BloomBlock(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, linear_method)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, linear_method)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
...@@ -203,7 +219,11 @@ class BloomBlock(nn.Module):

class BloomModel(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
...@@ -216,8 +236,10 @@ class BloomModel(nn.Module):
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -251,12 +273,15 @@ class BloomModel(nn.Module):

class BloomForCausalLM(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = BloomModel(config, linear_method)
        self.lm_head_weight = self.transformer.word_embeddings.weight
        self.sampler = Sampler(config.vocab_size)
...@@ -274,55 +299,36 @@ class BloomForCausalLM(nn.Module):
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size if output_dim is not None:
head_size = hidden_size // num_heads loaded_weight_shape = loaded_weight.shape
if "query_key_value.weight" in name: loaded_weight = loaded_weight.view(
loaded_weight = loaded_weight.view(-1, 3, head_size, loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
hidden_size) loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(
loaded_weight = loaded_weight.reshape(-1, hidden_size) output_dim, output_dim + 1)
elif "query_key_value.bias" in name: loaded_weight = loaded_weight.reshape(loaded_weight_shape)
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1) weight_loader = getattr(param, "weight_loader",
loaded_weight = loaded_weight.reshape(-1) default_weight_loader)
else: weight_loader(param, loaded_weight)
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
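The conversion above reorders BLOOM's fused QKV weight from a per-head [q, k, v] grouping into the [all Q | all K | all V] layout that the fused parallel linear layer expects. A minimal sketch of that view/transpose/reshape on a toy tensor (sizes are made up; output_dim is taken to be 0, i.e. a plain 2-D weight):

```python
import torch

# Toy sizes (hypothetical, for illustration only).
num_heads, head_size, hidden_size = 4, 8, 32

# Checkpoint layout: rows grouped per head as [q_h, k_h, v_h] blocks,
# i.e. shape (num_heads * 3 * head_size, hidden_size).
w = torch.arange(num_heads * 3 * head_size * hidden_size,
                 dtype=torch.float32).reshape(num_heads * 3 * head_size,
                                              hidden_size)

# Same view / transpose / reshape as the new loader, with output_dim == 0:
converted = (w.view(num_heads, 3, head_size, hidden_size)
             .transpose(0, 1)     # -> (3, num_heads, head_size, hidden_size)
             .reshape(w.shape))   # back to (3 * num_heads * head_size, hidden_size)

# Head 0's query rows stay in place ...
assert torch.equal(converted[:head_size], w[:head_size])
# ... and head 1's query rows (originally at offset 3 * head_size) follow them.
assert torch.equal(converted[head_size:2 * head_size],
                   w[3 * head_size:4 * head_size])
print(converted.shape)  # torch.Size([96, 32])
```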
...@@ -6,32 +6,28 @@ ...@@ -6,32 +6,28 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
hf_model_weights_iterator, VocabParallelEmbedding, ParallelLMHead)
load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size, from vllm.model_executor.weight_utils import (default_weight_loader,
) hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput
from vllm.model_executor.parallel_utils.layers import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -39,7 +35,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -39,7 +35,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -50,25 +50,33 @@ class GLMAttention(nn.Module): ...@@ -50,25 +50,33 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads = (config.multi_query_group_num self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else if config.multi_query_attention else
config.num_attention_heads) config.num_attention_heads)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0 assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.query_key_value = ColumnParallelLinear( self.query_key_value = QKVParallelLinear(
config.hidden_size, self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
bias=config.add_qkv_bias, self.total_num_heads,
gather_output=False, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE( self.attn = PagedAttentionWithRoPE(
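The partition-or-replicate rule introduced here for grouped/multi-query attention reduces to a small piece of arithmetic over head counts. A standalone sketch (hypothetical helper and sizes, roughly ChatGLM2-like) of how the per-GPU head counts and the q/kv slice sizes fall out:

```python
def per_gpu_head_counts(total_num_heads: int, total_num_kv_heads: int,
                        head_dim: int, tp_size: int):
    """Mirror of the partition-or-replicate rule used in the attention modules."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: shard them across the tensor-parallel GPUs.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than GPUs: each GPU keeps a replicated copy.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return num_heads, num_kv_heads, q_size, kv_size

# E.g. a ChatGLM2-like setup: 32 query heads, 2 KV groups, head_dim 128.
print(per_gpu_head_counts(32, 2, 128, tp_size=2))  # (16, 1, 2048, 128)
print(per_gpu_head_counts(32, 2, 128, tp_size=4))  # (8, 1, 1024, 128) -> KV replicated
```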
...@@ -78,7 +86,6 @@ class GLMAttention(nn.Module): ...@@ -78,7 +86,6 @@ class GLMAttention(nn.Module):
rotary_dim=self.head_dim // 2, rotary_dim=self.head_dim // 2,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
is_neox_style=False, is_neox_style=False,
# is_glm_style=True
) )
def forward( def forward(
...@@ -117,17 +124,21 @@ class GLMMLP(nn.Module): ...@@ -117,17 +124,21 @@ class GLMMLP(nn.Module):
state back into h hidden dimension. state back into h hidden dimension.
""" """
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.add_bias = config.add_bias_linear self.add_bias = config.add_bias_linear
# Project to 4h. # Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.ffn_hidden_size * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
gather_output=False, linear_method=linear_method,
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
...@@ -137,7 +148,7 @@ class GLMMLP(nn.Module): ...@@ -137,7 +148,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size, config.ffn_hidden_size,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
input_is_parallel=True, linear_method=linear_method,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
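With the gate and up projections merged into a single column-parallel weight of size [ffn_hidden_size] * 2, the activation sees one concatenated tensor and SiluAndMul splits it in half. A framework-free sketch of that contract, with the merged layer stubbed by a plain nn.Linear (so no tensor-parallel sharding is modeled):

```python
import torch
from torch import nn
import torch.nn.functional as F

class SiluAndMulSketch(nn.Module):
    """silu(gate) * up over a tensor packed as [gate | up] on the last dim."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up

hidden_size, ffn_hidden_size = 64, 256
dense_h_to_4h = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # gate and up, fused
dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
act = SiluAndMulSketch()

x = torch.randn(3, hidden_size)
out = dense_4h_to_h(act(dense_h_to_4h(x)))
print(out.shape)  # torch.Size([3, 64])
```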
...@@ -159,6 +170,7 @@ class GLMBlock(nn.Module): ...@@ -159,6 +170,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
...@@ -172,7 +184,7 @@ class GLMBlock(nn.Module): ...@@ -172,7 +184,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config) self.self_attention = GLMAttention(config, linear_method)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -180,7 +192,7 @@ class GLMBlock(nn.Module): ...@@ -180,7 +192,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
# MLP # MLP
self.mlp = GLMMLP(config) self.mlp = GLMMLP(config, linear_method)
def forward( def forward(
self, self,
...@@ -227,7 +239,11 @@ class GLMBlock(nn.Module): ...@@ -227,7 +239,11 @@ class GLMBlock(nn.Module):
class GLMTransformer(nn.Module): class GLMTransformer(nn.Module):
"""Transformer class.""" """Transformer class."""
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
...@@ -236,7 +252,7 @@ class GLMTransformer(nn.Module): ...@@ -236,7 +252,7 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[GLMBlock(config) for i in range(self.num_layers)]) [GLMBlock(config, linear_method) for i in range(self.num_layers)])
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
...@@ -274,7 +290,11 @@ class GLMTransformer(nn.Module): ...@@ -274,7 +290,11 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module): class ChatGLMModel(nn.Module):
def __init__(self, config): def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
...@@ -283,15 +303,10 @@ class ChatGLMModel(nn.Module): ...@@ -283,15 +303,10 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config) self.encoder = GLMTransformer(config, linear_method)
self.output_layer = ColumnParallelLinear( self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size, config.hidden_size)
config.padded_vocab_size,
bias=False,
gather_output=False,
params_dtype=config.torch_dtype,
)
def forward( def forward(
self, self,
...@@ -317,10 +332,15 @@ class ChatGLMModel(nn.Module): ...@@ -317,10 +332,15 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module): class ChatGLMForCausalLM(nn.Module):
def __init__(self, config: ChatGLMConfig): def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.transformer = ChatGLMModel(config) self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.sampler = Sampler(config.padded_vocab_size) self.sampler = Sampler(config.padded_vocab_size)
...@@ -331,78 +351,26 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -331,78 +351,26 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [ def load_weights(self,
"output_layer.weight",
"embedding.weight",
]
_row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]
def load_weights(
self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None, revision: Optional[str] = None):
): params_dict = dict(self.named_parameters(remove_duplicate=False))
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
q_proj_shard_size = self.config.hidden_size // tp_size
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.multi_query_group_num // tp_size)
mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name: if "word_embeddings" in name:
name = name.replace(".word_embeddings", "") name = name.replace(".word_embeddings", "")
param = params_dict[name]
if name in state_dict: weight_loader = getattr(param, "weight_loader",
param = state_dict[name] default_weight_loader)
if "query_key_value" in name: weight_loader(param, loaded_weight)
q_offset = q_proj_shard_size * tp_rank
k_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * tp_rank)
v_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * (tp_size + tp_rank))
wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param.data.copy_(loaded_weight)
continue
if "dense_h_to_4h" in name:
w_gate = loaded_weight[mlp_hidden_shard_size *
tp_rank:mlp_hidden_shard_size *
(tp_rank + 1)]
w_proj = loaded_weight[mlp_hidden_shard_size *
(tp_size +
tp_rank):mlp_hidden_shard_size *
(tp_size + tp_rank + 1)]
loaded_weight = torch.cat([w_gate, w_proj], dim=0)
param.data.copy_(loaded_weight)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
elif name == "transformer.rotary_pos_emb.inv_freq":
continue
else:
print("Warning never found tensor's name:", name)
...@@ -30,17 +30,19 @@ from vllm.model_executor.input_metadata import InputMetadata ...@@ -30,17 +30,19 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention, from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi, PagedAttentionWithALiBi,
PagedAttentionWithRoPE) PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, LinearMethodBase,
hf_model_weights_iterator, QKVParallelLinear,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
...@@ -48,19 +50,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -48,19 +50,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
...@@ -86,7 +75,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: ...@@ -86,7 +75,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class FalconAttention(nn.Module): class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -103,41 +96,29 @@ class FalconAttention(nn.Module): ...@@ -103,41 +96,29 @@ class FalconAttention(nn.Module):
if self.new_decoder_architecture: if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
elif self.multi_query: elif self.multi_query:
self.total_num_kv_heads = 1 self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else: else:
self.total_num_kv_heads = self.total_num_heads self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads if self.total_num_kv_heads >= tp_size:
self.query_key_value = ColumnParallelLinear( # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size, self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias, bias=config.bias,
gather_output=False,
skip_bias_add=True, skip_bias_add=True,
linear_method=linear_method,
) )
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
...@@ -149,7 +130,6 @@ class FalconAttention(nn.Module): ...@@ -149,7 +130,6 @@ class FalconAttention(nn.Module):
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=config.bias, bias=config.bias,
input_is_parallel=True,
skip_bias_add=True, skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results) reduce_results=self.reduce_row_parallel_results)
...@@ -196,18 +176,10 @@ class FalconAttention(nn.Module): ...@@ -196,18 +176,10 @@ class FalconAttention(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query:
q, bias = self.query(hidden_states)
if bias is not None:
q += bias
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if bias is not None: if bias is not None:
qkv += bias qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
if self.use_rotary: if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
...@@ -221,15 +193,19 @@ class FalconAttention(nn.Module): ...@@ -221,15 +193,19 @@ class FalconAttention(nn.Module):
class FalconMLP(nn.Module): class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size, self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size, 4 * hidden_size,
bias=config.bias, bias=config.bias,
gather_output=False, skip_bias_add=True,
skip_bias_add=True) linear_method=linear_method)
self.act = nn.GELU() self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn) or config.parallel_attn)
...@@ -237,9 +213,9 @@ class FalconMLP(nn.Module): ...@@ -237,9 +213,9 @@ class FalconMLP(nn.Module):
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
bias=config.bias, bias=config.bias,
input_is_parallel=True,
skip_bias_add=True, skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results) reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here. # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
...@@ -253,12 +229,16 @@ class FalconMLP(nn.Module): ...@@ -253,12 +229,16 @@ class FalconMLP(nn.Module):
class FalconDecoderLayer(nn.Module): class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config) self.self_attention = FalconAttention(config, linear_method)
self.mlp = FalconMLP(config) self.mlp = FalconMLP(config, linear_method)
self.config = config self.config = config
if config.new_decoder_architecture: if config.new_decoder_architecture:
...@@ -334,7 +314,11 @@ class FalconDecoderLayer(nn.Module): ...@@ -334,7 +314,11 @@ class FalconDecoderLayer(nn.Module):
class FalconModel(nn.Module): class FalconModel(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -349,7 +333,8 @@ class FalconModel(nn.Module): ...@@ -349,7 +333,8 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers) FalconDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
]) ])
# Final Layer Norm # Final Layer Norm
...@@ -383,15 +368,18 @@ class FalconModel(nn.Module): ...@@ -383,15 +368,18 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module): class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig): def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = FalconModel(config) self.linear_method = linear_method
self.lm_head = ColumnParallelLinear( self.transformer = FalconModel(config, linear_method)
config.hidden_size, self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
bias=False, config.hidden_size,
gather_output=False,
) )
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
...@@ -415,89 +403,44 @@ class FalconForCausalLM(nn.Module): ...@@ -415,89 +403,44 @@ class FalconForCausalLM(nn.Module):
return next_tokens return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture: if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query: elif self.config.multi_query:
total_num_kv_heads = 1 total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else: else:
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict() params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight) output_dim = getattr(param, "output_dim", None)
loaded_weight_size = loaded_weight.size() loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view( loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2, loaded_weight_shape[:output_dim] +
head_size, *loaded_weight_size[1:]) (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
loaded_weight_shape[output_dim + 1:])
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) wq = loaded_weight.narrow(
wk = loaded_weight[:, [-2]].reshape(-1, output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
*loaded_weight_size[1:]) *loaded_weight_shape[:output_dim], -1,
wv = loaded_weight[:, [-1]].reshape(-1, *loaded_weight_shape[output_dim + 1:])
*loaded_weight_size[1:]) wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
wq = wq[head_size * head_start:head_size * head_end] 1).reshape(*loaded_weight_shape[:output_dim], -1,
wk = wk[head_size * kv_head_start:head_size * kv_head_end] *loaded_weight_shape[output_dim + 1:])
wv = wv[head_size * kv_head_start:head_size * kv_head_end] wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
if separated_q_kv: 1).reshape(*loaded_weight_shape[:output_dim], -1,
loaded_weight_q = wq *loaded_weight_shape[output_dim + 1:])
loaded_weight_kv = torch.cat([wk, wv], dim=0) loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value", weight_loader = getattr(param, "weight_loader",
"key_value") default_weight_loader)
load_tensor_parallel_weights(state_dict[q_weight_name], weight_loader(param, loaded_weight)
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
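Falcon's checkpoint stores the fused QKV weight with each KV group laid out as [q_0 ... q_{g-1}, k, v]; the new loader views it as (num_kv_heads, g + 2, head_size, ...), narrows out the Q, K, and V blocks, and concatenates them back into [all Q | all K | all V] order. A toy-tensor sketch of that narrow/reshape (made-up sizes, output_dim assumed to be 0):

```python
import torch

# Hypothetical sizes for illustration.
num_kv_heads, group, head_size, hidden = 2, 4, 8, 16   # group = queries per KV head
total_rows = num_kv_heads * (group + 2) * head_size

w = torch.arange(total_rows * hidden, dtype=torch.float32).reshape(total_rows, hidden)

# View the output dimension as (kv_head, q_0..q_{g-1} | k | v, head_size).
grouped = w.view(num_kv_heads, group + 2, head_size, hidden)
wq = grouped.narrow(1, 0, group).reshape(-1, hidden)       # all query heads
wk = grouped.narrow(1, group, 1).reshape(-1, hidden)       # one key head per group
wv = grouped.narrow(1, group + 1, 1).reshape(-1, hidden)   # one value head per group
fused = torch.cat([wq, wk, wv], dim=0)

print(wq.shape, wk.shape, wv.shape, fused.shape)
# torch.Size([64, 16]) torch.Size([16, 16]) torch.Size([16, 16]) torch.Size([96, 16])
```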
...@@ -30,15 +30,17 @@ from transformers import GPT2Config ...@@ -30,15 +30,17 @@ from transformers import GPT2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
convert_pyslice_to_tensor, hf_model_weights_iterator, VocabParallelEmbedding)
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -46,7 +48,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -46,7 +48,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
...@@ -57,17 +63,18 @@ class GPT2Attention(nn.Module): ...@@ -57,17 +63,18 @@ class GPT2Attention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear( self.c_attn = QKVParallelLinear(
self.hidden_size, self.hidden_size,
3 * self.hidden_size, self.head_dim,
total_num_heads,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -95,6 +102,7 @@ class GPT2MLP(nn.Module): ...@@ -95,6 +102,7 @@ class GPT2MLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -102,13 +110,13 @@ class GPT2MLP(nn.Module): ...@@ -102,13 +110,13 @@ class GPT2MLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
...@@ -121,16 +129,20 @@ class GPT2MLP(nn.Module): ...@@ -121,16 +129,20 @@ class GPT2MLP(nn.Module):
class GPT2Block(nn.Module): class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 * inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config) self.attn = GPT2Attention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config) self.mlp = GPT2MLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
...@@ -160,24 +172,23 @@ class GPT2Block(nn.Module): ...@@ -160,24 +172,23 @@ class GPT2Block(nn.Module):
class GPT2Model(nn.Module): class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList( self.h = nn.ModuleList([
[GPT2Block(config) for _ in range(config.num_hidden_layers)]) GPT2Block(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -207,12 +218,15 @@ class GPT2Model(nn.Module): ...@@ -207,12 +218,15 @@ class GPT2Model(nn.Module):
class GPT2LMHeadModel(nn.Module): class GPT2LMHeadModel(nn.Module):
def __init__(self, config: GPT2Config): def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = GPT2Model(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.transformer = GPT2Model(config, linear_method)
# parallelism
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
...@@ -230,19 +244,12 @@ class GPT2LMHeadModel(nn.Module): ...@@ -230,19 +244,12 @@ class GPT2LMHeadModel(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tensor_model_parallel_world_size = ( params_dict = dict(self.named_parameters(remove_duplicate=False))
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
...@@ -253,53 +260,19 @@ class GPT2LMHeadModel(nn.Module): ...@@ -253,53 +260,19 @@ class GPT2LMHeadModel(nn.Module):
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
param = params_dict[name]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights. # Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name: if conv1d_weight_name not in name:
continue continue
if not name.endswith(".weight"): if not name.endswith(".weight"):
continue continue
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
param = state_dict[name]
if name == "transformer.wte.weight": weight_loader = getattr(param, "weight_loader",
load_padded_tensor_parallel_vocab(param, loaded_weight, default_weight_loader)
tensor_model_parallel_rank) weight_loader(param, loaded_weight)
continue
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
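The transpose above exists because HF's GPT-2 keeps its projection weights in Conv1D layout, (in_features, out_features), while the linear layers here expect (out_features, in_features); only *.weight tensors are transposed, biases are left alone. A quick sketch of the equivalence with toy sizes:

```python
import torch
from torch import nn

hidden, out = 16, 48
# HF GPT-2's Conv1D keeps its weight as (in_features, out_features) ...
conv1d_style_weight = torch.randn(hidden, out)
# ... while nn.Linear expects (out_features, in_features), hence the .t().
linear = nn.Linear(hidden, out, bias=False)
linear.weight.data.copy_(conv1d_style_weight.t())

x = torch.randn(2, hidden)
# Both formulations compute the same projection.
assert torch.allclose(x @ conv1d_style_weight, linear(x), atol=1e-6)
```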
...@@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig ...@@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
convert_pyslice_to_tensor, hf_model_weights_iterator, VocabParallelEmbedding)
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
...@@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module): ...@@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module):
self.multi_query = config.multi_query self.multi_query = config.multi_query
if self.multi_query: if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1 self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else: else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = ColumnParallelLinear( self.c_attn = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size + 2 * self.kv_dim, self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module): ...@@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([ q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size, self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim self.kv_dim, self.kv_dim
], ],
dim=-1) dim=-1,
)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache, attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event) input_metadata, cache_event)
...@@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module): ...@@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -131,13 +129,13 @@ class GPTBigMLP(nn.Module): ...@@ -131,13 +129,13 @@ class GPTBigMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
gather_output=False, linear_method=linear_method,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
...@@ -150,16 +148,20 @@ class GPTBigMLP(nn.Module): ...@@ -150,16 +148,20 @@ class GPTBigMLP(nn.Module):
class GPTBigCodeBlock(nn.Module): class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 * inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config) self.attn = GPTBigCodeAttention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config) self.mlp = GPTBigMLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
...@@ -189,23 +191,23 @@ class GPTBigCodeBlock(nn.Module): ...@@ -189,23 +191,23 @@ class GPTBigCodeBlock(nn.Module):
class GPTBigCodeModel(nn.Module): class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList( self.h = nn.ModuleList([
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) GPTBigCodeBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -235,12 +237,15 @@ class GPTBigCodeModel(nn.Module): ...@@ -235,12 +237,15 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module): class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = GPTBigCodeModel(config) self.linear_method = linear_method
# TODO(zhuohan): create a new weight after implementing pipeline self.transformer = GPTBigCodeModel(config, linear_method)
# parallelism
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
...@@ -258,89 +263,21 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -258,89 +263,21 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tensor_model_parallel_world_size = ( params_dict = dict(self.named_parameters(remove_duplicate=False))
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue continue
if ".attn.bias" in name: if ".attn.bias" in name:
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
param = params_dict[name]
if not name.startswith("transformer."): weight_loader = getattr(param, "weight_loader",
name = "transformer." + name default_weight_loader)
weight_loader(param, loaded_weight)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
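Note on the change above: the removed block hand-rolled the tensor-parallel slicing for every weight (splitting the fused c_attn along the head dimension, padding the vocab embedding, and so on), while the replacement simply looks up a weight_loader attribute on each parameter and falls back to default_weight_loader for plain, unsharded tensors. A minimal sketch of what such a fallback loader amounts to, assuming it only copies a replicated or already-correctly-sized tensor into the parameter (the real implementation lives in vllm.model_executor.weight_utils):

import torch

def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # No tensor-parallel slicing here: the parameter is either replicated or
    # already sized for this rank, so the checkpoint tensor is copied as-is.
    assert param.data.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)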
...@@ -29,14 +29,17 @@ from transformers import GPTJConfig ...@@ -29,14 +29,17 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, from vllm.model_executor.layers.vocab_parallel_embedding import (
load_tensor_parallel_weights) VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, from vllm.model_executor.weight_utils import (default_weight_loader,
ColumnParallelLinear, hf_model_weights_iterator)
RowParallelLinear)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
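The new imports above swap the per-model parallel layers for the shared ColumnParallelLinear, RowParallelLinear, and QKVParallelLinear layers, all of which accept an optional linear_method. That argument is the hook that makes every model quantizable: the layer asks the method to create its weights and to apply them, so an AWQ or SqueezeLLM method can substitute packed weights and a custom kernel without touching the model code. A rough sketch of what an unquantized method could look like, assuming a create_weights/apply_weights split; the exact signatures in vLLM differ, and the class name here is illustrative only:

from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F

class UnquantizedLinearMethodSketch:
    # Illustrative stand-in for the default (non-quantized) linear method.
    def create_weights(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        # A quantized method would allocate packed int weights plus scales here.
        weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype))
        return {"weight": weight}

    def apply_weights(self, weights: Dict[str, Any], x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # A quantized method would call its dequantize/GEMM kernel here instead.
        return F.linear(x, weights["weight"], bias)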
...@@ -44,23 +47,28 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -44,23 +47,28 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = QKVParallelLinear(
config.hidden_size, config.hidden_size,
3 * config.hidden_size, self.head_size,
self.total_num_heads,
bias=False, bias=False,
gather_output=False, linear_method=linear_method,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
input_is_parallel=True, linear_method=linear_method,
) )
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
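The fused QKV projection above replaces the plain ColumnParallelLinear(hidden_size, 3 * hidden_size). GPT-J has no grouped-query attention, so every tensor-parallel rank ends up holding the same number of output rows either way; a quick arithmetic check, using GPT-J-6B-style sizes as assumed example values:

# Example values assumed for illustration (GPT-J-6B: 4096 hidden, 16 heads).
hidden_size, total_num_heads, tp_size = 4096, 16, 2
head_size = hidden_size // total_num_heads            # 256
heads_per_rank = total_num_heads // tp_size           # 8 query (and KV) heads
qkv_rows_per_rank = 3 * heads_per_rank * head_size    # 6144 output rows per rank
assert qkv_rows_per_rank == 3 * hidden_size // tp_size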
...@@ -102,18 +110,23 @@ class GPTJAttention(nn.Module): ...@@ -102,18 +110,23 @@ class GPTJAttention(nn.Module):
class GPTJMLP(nn.Module): class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig): def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
hidden_size = config.n_embd hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear( self.fc_in = ColumnParallelLinear(
hidden_size, hidden_size,
intermediate_size, intermediate_size,
gather_output=False, linear_method=linear_method,
) )
self.fc_out = RowParallelLinear( self.fc_out = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
input_is_parallel=True, linear_method=linear_method,
) )
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
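The MLP keeps the Megatron-style layout: fc_in is column-parallel (each rank holds a slice of the intermediate dimension) and fc_out is row-parallel, so the activation never has to be gathered in between and only fc_out performs an all-reduce. A sketch of the forward this implies, written as a free function for clarity and assuming the (output, bias) return convention these layers use:

import torch

def mlp_forward(mlp: "GPTJMLP", hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states, _ = mlp.fc_in(hidden_states)   # local shard of the intermediate dim
    hidden_states = mlp.act(hidden_states)        # elementwise, no communication
    hidden_states, _ = mlp.fc_out(hidden_states)  # all-reduce combines rank outputs
    return hidden_states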
...@@ -126,15 +139,19 @@ class GPTJMLP(nn.Module): ...@@ -126,15 +139,19 @@ class GPTJMLP(nn.Module):
class GPTJBlock(nn.Module): class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
if config.n_inner is None: if config.n_inner is None:
inner_dim = 4 * config.n_embd inner_dim = 4 * config.n_embd
else: else:
inner_dim = config.n_inner inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config) self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config) self.mlp = GPTJMLP(inner_dim, config, linear_method)
def forward( def forward(
self, self,
...@@ -160,7 +177,11 @@ class GPTJBlock(nn.Module): ...@@ -160,7 +177,11 @@ class GPTJBlock(nn.Module):
class GPTJModel(nn.Module): class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.n_embd self.embed_dim = config.n_embd
...@@ -169,7 +190,7 @@ class GPTJModel(nn.Module): ...@@ -169,7 +190,7 @@ class GPTJModel(nn.Module):
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList( self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)]) [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -200,15 +221,20 @@ class GPTJModel(nn.Module): ...@@ -200,15 +221,20 @@ class GPTJModel(nn.Module):
class GPTJForCausalLM(nn.Module): class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig): def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
self.transformer = GPTJModel(config) self.transformer = GPTJModel(config, linear_method)
self.lm_head = ColumnParallelLinear( self.lm_head = ParallelLMHead(
config.n_embd,
config.vocab_size, config.vocab_size,
gather_output=False, config.n_embd,
bias=True,
) )
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
...@@ -226,43 +252,33 @@ class GPTJForCausalLM(nn.Module): ...@@ -226,43 +252,33 @@ class GPTJForCausalLM(nn.Module):
input_metadata, self.lm_head.bias) input_metadata, self.lm_head.bias)
return next_tokens return next_tokens
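The LM head becomes a ParallelLMHead, i.e. a vocab-parallel weight matrix (its bias is kept, since the forward above still passes self.lm_head.bias to the sampler). The point is that sampling only needs the weight to form logits, so each rank can score its own vocabulary shard before the shards are gathered. A rough sketch of that step, with the function and variable names chosen here for illustration rather than taken from the Sampler's actual internals:

from typing import Optional
import torch

def logits_for_vocab_shard(hidden_states: torch.Tensor,
                           lm_head_weight: torch.Tensor,
                           lm_head_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # lm_head_weight: the [vocab_size_per_rank, n_embd] slice held by this rank.
    logits = torch.matmul(hidden_states, lm_head_weight.t())
    if lm_head_bias is not None:
        logits = logits + lm_head_bias
    return logits  # shards are gathered across ranks before sampling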
_column_parallel_weights = [
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
"lm_head.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank() stacked_params_mapping = [
state_dict = self.state_dict() # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "attn.bias" in name or "attn.masked_bias" in name: if "attn.bias" in name or "attn.masked_bias" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
is_attention_weight = False if weight_name not in name:
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = params_dict[name.replace(weight_name, param_name)]
shard_size = param.shape[0] // 3 weight_loader = param.weight_loader
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * weight_loader(param, loaded_weight, shard_id)
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break break
if is_attention_weight: else:
continue param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
param = state_dict[name] default_weight_loader)
load_tensor_parallel_weights(param, loaded_weight, name, weight_loader(param, loaded_weight)
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
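The stacked_params_mapping above is what lets the checkpoint's separate q_proj/k_proj/v_proj (and gate_proj/up_proj) tensors land in the fused parameters: the loop rewrites the checkpoint name to the fused parameter's name and passes the shard id to that parameter's weight_loader. A simplified sketch of such a shard-aware loader for the GPT-J case (equal-sized q, k, v, no grouped-query attention); the real loader is attached by QKVParallelLinear and handles more cases:

import torch

def qkv_weight_loader_sketch(param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor,
                             shard_id: str,
                             tp_rank: int) -> None:
    # param holds the [q | k | v] rows for this rank; pick the destination block.
    shard_size = param.data.shape[0] // 3
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * shard_size
    # Slice this rank's rows out of the full checkpoint tensor, then copy them in.
    start = tp_rank * shard_size
    param.data[offset:offset + shard_size].copy_(
        loaded_weight[start:start + shard_size])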