Unverified commit 7076fa1c authored by Zhuohan Li, committed by GitHub

TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580).
- The model loading code is much simpler.
- Tensor parallelism is supported for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size (a minimal sketch of the per-GPU KV-head computation follows below).
parent 660a7fcf
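As a rough illustration of the last point, the number of KV heads each GPU ends up with is now derived from the total KV head count, with replication when the tensor parallel size exceeds it. A minimal sketch mirroring the new `get_num_kv_heads` logic (the head counts below are made up for illustration):

def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
    # Replicate KV heads when there are fewer heads than GPUs, so that every
    # tensor-parallel rank still owns at least one KV head.
    return max(1, total_num_kv_heads // tp_size)

assert num_kv_heads_per_gpu(8, 4) == 2   # 8 GQA heads partitioned over 4 GPUs
assert num_kv_heads_per_gpu(8, 16) == 1  # fewer heads than GPUs: replicate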
......@@ -140,8 +140,8 @@ class ModelConfig:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
......@@ -155,23 +155,34 @@ class ModelConfig:
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
if getattr(self.hf_config, "num_kv_heads", None) is not None:
return (self.hf_config.num_kv_heads //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
parallel_config.tensor_parallel_size)
# For ChatGLM-2:
if getattr(self.hf_config, "multi_query_group_num", None) is not None:
return (self.hf_config.multi_query_group_num //
parallel_config.tensor_parallel_size)
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size
attributes = [
# For Falcon:
"n_head_kv",
"num_kv_heads",
# For LLaMA-2:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
......
......@@ -142,10 +142,10 @@ class RequestTracker:
self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[dict] = []
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
......
This diff is collapsed.
from typing import Type
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
_QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
]
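A minimal usage sketch of the registry above, assuming this module is the package `__init__` for `vllm.model_executor.layers.quantization` and using illustrative AWQ config values:

from vllm.model_executor.layers.quantization import get_quantization_config

quant_cls = get_quantization_config("awq")          # -> AWQConfig
quant_config = quant_cls.from_config(
    {"w_bit": 4, "q_group_size": 128, "zero_point": True})
linear_method = quant_config.get_linear_method()    # -> AWQLinearMethod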
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
def get_name(self) -> str:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@staticmethod
def get_config_filenames() -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod":
return AWQLinearMethod(self)
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
qzeros = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
input_size // self.quant_config.group_size,
output_size,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
return {
"qweight": qweight,
"qzeros": qzeros,
"scales": scales,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
qzeros = weights["qzeros"]
scales = weights["scales"]
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
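To make the packing arithmetic above concrete, here is a small sketch of the expected tensor shapes for a hypothetical 4096x4096 linear layer with common AWQ settings (4-bit weights, group size 128):

weight_bits, group_size = 4, 128
pack_factor = 32 // weight_bits                  # 8 INT4 values per INT32
input_size, output_size = 4096, 4096

qweight_shape = (input_size, output_size // pack_factor)       # (4096, 512)
qzeros_shape = (input_size // group_size,
                output_size // pack_factor)                    # (32, 512)
scales_shape = (input_size // group_size, output_size)         # (32, 4096)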
from typing import Any, Dict, List, Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizationConfig:
@classmethod
def get_name(cls) -> str:
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
@abstractmethod
def get_name(self) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
def get_min_capability(cls) -> int:
@abstractmethod
def get_min_capability(self) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
......@@ -25,12 +29,14 @@ class QuantizationConfig:
"""
raise NotImplementedError
@classmethod
def get_config_filenames(cls) -> List[str]:
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
......@@ -44,42 +50,7 @@ class QuantizationConfig:
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@classmethod
def get_packed_tensors(cls) -> Dict[str, int]:
"""Returns a dictionary of packed tensor names and their pack dims."""
raise NotImplementedError
@classmethod
def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
"""Returns the pack dim of a tensor if it is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
If the tensor is not packed, returns None.
"""
packed_tensors = cls.get_packed_tensors()
for packed_tensor_name, pack_dim in packed_tensors.items():
if packed_tensor_name in tensor_name:
return pack_dim
return None
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_transposed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
"""
return any(tag in tensor_name
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_col_parallel_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def get_row_parallel_tensor_names(cls) -> List[str]:
@abstractmethod
def get_linear_method(self) -> LinearMethodBase:
"""Get the linear method to use for the quantized linear layer."""
raise NotImplementedError
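A sketch of how a caller might use `get_min_capability` to guard against unsupported GPUs; the helper below is an assumption for illustration, not the loader's exact code:

import torch

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


def check_gpu_capability(quant_config: QuantizationConfig) -> None:
    # Compare the current GPU's compute capability (e.g., 80 for Ampere)
    # against the minimum required by the quantization method.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < quant_config.get_min_capability():
        raise ValueError(
            f"{quant_config.get_name()} requires compute capability "
            f">= {quant_config.get_min_capability()}, but got {capability}.")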
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
def get_name(self) -> str:
return "squeezellm"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
return SqueezeLLMLinearMethod(self)
class SqueezeLLMLinearMethod(LinearMethodBase):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
if input_size % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size // self.quant_config.pack_factor,
output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
lookup_table = Parameter(
torch.empty(
output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
return {
"qweight": qweight,
"lookup_table": lookup_table,
}
def apply_weights(self,
weights: Dict[str, torch.Tensor],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
lookup_table = weights["lookup_table"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
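Analogous shape arithmetic for the SqueezeLLM weights above, again for a hypothetical 4096x4096 layer with 4-bit weights:

weight_bits = 4
pack_factor = 32 // weight_bits                        # 8
input_size, output_size = 4096, 4096

qweight_shape = (input_size // pack_factor, output_size)   # (512, 4096)
lookup_table_shape = (output_size, weight_bits**2)         # (4096, 16)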
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.group_size == 0
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.output_size % self.quant_config.pack_factor == 0
if self.input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size // self.quant_config.pack_factor,
self.output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size_per_partition,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class SqueezeLLMRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
return out.reshape(out_shape)
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
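A quick worked example of the helpers above, with hypothetical numbers: a vocabulary of 32,001 tokens padded and then split across two tensor-parallel ranks.

vocab_size = pad_vocab_size(32001)                # 32064 (padded to 64)
first, last = vocab_range_from_global_vocab_size(
    vocab_size, rank=1, world_size=2)             # (16032, 32064)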
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
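A small worked example of the forward pass above, with assumed numbers (tp_size=2, padded vocabulary of 8):

# rank 0 owns ids [0, 4); rank 1 owns ids [4, 8).
# For input id 5 on rank 0: it falls outside [0, 4), so it is masked to 0 and
# the corresponding output row is zeroed out.
# For input id 5 on rank 1: local index 5 - 4 = 1, the real embedding row.
# The final all-reduce sums the two partial outputs, so every rank ends up
# with the correct embedding for id 5.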
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
The output logits weight matrix, used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
else:
self.register_parameter("bias", None)
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
......@@ -37,13 +37,6 @@ _MODEL_REGISTRY = {
"YiForCausalLM": YiForCausalLM,
}
# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
LlamaForCausalLM,
MistralForCausalLM,
YiForCausalLM,
]
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
......@@ -67,12 +60,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config.
quant_config = None
# Get the (maybe quantized) linear method.
linear_method = None
if model_config.quantization is not None:
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
raise ValueError(
f"Quantization is not supported for {model_class}.")
quant_config = get_quant_config(model_config.quantization,
model_config.model,
model_config.download_dir)
......@@ -90,14 +80,12 @@ def get_model(model_config: ModelConfig) -> nn.Module:
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
model = model_class(model_config.hf_config, quant_config)
else:
model = model_class(model_config.hf_config)
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
......
......@@ -33,15 +33,17 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
......@@ -55,20 +57,17 @@ class AquilaMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -111,6 +110,7 @@ class AquilaAttention(nn.Module):
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -128,29 +128,29 @@ class AquilaAttention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear(
self.qkv_proj = QKVParallelLinear(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
gather_output=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling,
)
rope_scaling=rope_scaling)
def forward(
self,
......@@ -171,7 +171,11 @@ class AquilaAttention(nn.Module):
class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig):
def __init__(
self,
config: AquilaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
......@@ -185,11 +189,13 @@ class AquilaDecoderLayer(nn.Module):
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
rope_scaling=rope_scaling,
linear_method=linear_method,
)
self.mlp = AquilaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -226,19 +232,22 @@ class AquilaDecoderLayer(nn.Module):
class AquilaModel(nn.Module):
def __init__(self, config: AquilaConfig):
def __init__(
self,
config: AquilaConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
#vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
AquilaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -271,17 +280,16 @@ class AquilaModel(nn.Module):
class AquilaForCausalLM(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.model = AquilaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.linear_method = linear_method
self.model = AquilaModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
......@@ -298,79 +306,33 @@ class AquilaForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
state_dict = self.state_dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
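A sketch of how the stacked_params_mapping above is applied during loading (the checkpoint tensor names are hypothetical):

# "model.layers.0.self_attn.q_proj.weight" matches the ("qkv_proj", "q_proj",
# "q") entry, so it is loaded into the fused parameter
# "model.layers.0.self_attn.qkv_proj.weight" with shard_id "q"; k_proj and
# v_proj land in the same fused parameter under their own shard ids.
# Likewise, gate_proj / up_proj checkpoint tensors are loaded into
# gate_up_proj as shards 0 and 1.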
......@@ -30,18 +30,20 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
......@@ -80,20 +82,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
input_is_parallel=True,
)
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -116,6 +115,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -131,17 +131,19 @@ class BaiChuanAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
self.W_pack = QKVParallelLinear(
hidden_size,
3 * hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_heads,
bias=False,
gather_output=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
......@@ -188,7 +190,10 @@ class BaiChuanAttention(nn.Module):
class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
def __init__(self,
config: BaiChuanConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
......@@ -200,11 +205,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -241,7 +248,10 @@ class BaiChuanDecoderLayer(nn.Module):
class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
def __init__(self,
config: BaiChuanConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -252,7 +262,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding)
BaiChuanDecoderLayer(config, position_embedding, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -285,16 +295,15 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config, position_embedding: str):
def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
......@@ -311,79 +320,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = []
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config):
super().__init__(config, "ALIBI")
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ALIBI", linear_method)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config):
super().__init__(config, "ROPE")
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ROPE", linear_method)
......@@ -30,14 +30,17 @@ from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -70,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
......@@ -81,17 +88,18 @@ class BloomAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = ColumnParallelLinear(
self.query_key_value = QKVParallelLinear(
self.hidden_size,
3 * self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
# Create the alibi slopes and slice them.
......@@ -125,19 +133,23 @@ class BloomAttention(nn.Module):
class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
gather_output=False,
linear_method=linear_method,
)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
input_is_parallel=True,
linear_method=linear_method,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -149,16 +161,20 @@ class BloomMLP(nn.Module):
class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config)
self.self_attention = BloomAttention(config, linear_method)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.mlp = BloomMLP(config, linear_method)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
......@@ -203,7 +219,11 @@ class BloomBlock(nn.Module):
class BloomModel(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -216,8 +236,10 @@ class BloomModel(nn.Module):
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList(
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([
BloomBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -251,12 +273,15 @@ class BloomModel(nn.Module):
class BloomForCausalLM(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = BloomModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.linear_method = linear_method
self.transformer = BloomModel(config, linear_method)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size)
......@@ -274,55 +299,36 @@ class BloomForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.
self._column_parallel_weights.append(name)
# If lm_head is provided, use it instead.
param = self.lm_head_weight
else:
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the
# required shape is [3 * num_heads * head_size, hidden_size].
# NOTE: BLOOM's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
shard_size = param.shape[0]
start = shard_size * tp_rank
end = shard_size * (tp_rank + 1)
loaded_weight = loaded_weight[start:end]
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
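A worked example of the fused-QKV layout conversion above, with toy sizes (assumed num_heads=2, head_size=3, hidden_size=6):

# query_key_value.weight arrives as (num_heads * 3 * head_size, hidden_size)
#   = (18, 6), laid out head-major: [h0.q, h0.k, h0.v, h1.q, h1.k, h1.v].
# view      -> (2, 3, 3, 6)   # (num_heads, 3, head_size, hidden_size)
# transpose -> (3, 2, 3, 6)   # (3, num_heads, head_size, hidden_size)
# reshape   -> (18, 6), now laid out [h0.q, h1.q, h0.k, h1.k, h0.v, h1.v],
# which is the (3 * num_heads * head_size, hidden_size) layout that
# QKVParallelLinear expects.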
......@@ -6,32 +6,28 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
load_tensor_parallel_weights,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.layers import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -39,7 +35,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
......@@ -50,25 +50,33 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
if self.total_num_kv_heads >= tp_size:
# The number of KV heads is at least the TP size, so we partition
# the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
bias=config.add_qkv_bias,
gather_output=False,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
......@@ -78,7 +86,6 @@ class GLMAttention(nn.Module):
rotary_dim=self.head_dim // 2,
num_kv_heads=self.num_kv_heads,
is_neox_style=False,
# is_glm_style=True
)
def forward(
......@@ -117,17 +124,21 @@ class GLMMLP(nn.Module):
state back into h hidden dimension.
"""
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(
self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size * 2,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
gather_output=False,
linear_method=linear_method,
)
self.activation_func = SiluAndMul()
......@@ -137,7 +148,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
input_is_parallel=True,
linear_method=linear_method,
)
def forward(self, hidden_states):
......@@ -159,6 +170,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
......@@ -172,7 +184,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config)
self.self_attention = GLMAttention(config, linear_method)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
......@@ -180,7 +192,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config)
self.mlp = GLMMLP(config, linear_method)
def forward(
self,
......@@ -227,7 +239,11 @@ class GLMBlock(nn.Module):
class GLMTransformer(nn.Module):
"""Transformer class."""
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
......@@ -236,7 +252,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config) for i in range(self.num_layers)])
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
......@@ -274,7 +290,11 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
......@@ -283,15 +303,10 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config)
self.encoder = GLMTransformer(config, linear_method)
self.output_layer = ColumnParallelLinear(
config.hidden_size,
config.padded_vocab_size,
bias=False,
gather_output=False,
params_dtype=config.torch_dtype,
)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
def forward(
self,
......@@ -317,10 +332,15 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
def __init__(self, config: ChatGLMConfig):
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.transformer = ChatGLMModel(config)
self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.lm_head_weight = self.transformer.output_layer.weight
self.sampler = Sampler(config.padded_vocab_size)
......@@ -331,78 +351,26 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"output_layer.weight",
"embedding.weight",
]
_row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
q_proj_shard_size = self.config.hidden_size // tp_size
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.multi_query_group_num // tp_size)
mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size
state_dict = self.state_dict()
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
if name in state_dict:
param = state_dict[name]
if "query_key_value" in name:
q_offset = q_proj_shard_size * tp_rank
k_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * tp_rank)
v_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * (tp_size + tp_rank))
wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param.data.copy_(loaded_weight)
continue
if "dense_h_to_4h" in name:
w_gate = loaded_weight[mlp_hidden_shard_size *
tp_rank:mlp_hidden_shard_size *
(tp_rank + 1)]
w_proj = loaded_weight[mlp_hidden_shard_size *
(tp_size +
tp_rank):mlp_hidden_shard_size *
(tp_size + tp_rank + 1)]
loaded_weight = torch.cat([w_gate, w_proj], dim=0)
param.data.copy_(loaded_weight)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
elif name == "transformer.rotary_pos_emb.inv_freq":
continue
else:
print("Warning never found tensor's name:", name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
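The loading convention used above: each parameter may carry a weight_loader attribute set by its parallel layer, and when none is present a default loader is assumed to simply copy the checkpoint tensor. A minimal sketch of that fallback, under the assumption that shapes already match (the suffix marks it as illustrative, not the actual vLLM implementation):

import torch

def default_weight_loader_sketch(param: torch.nn.Parameter,
                                 loaded_weight: torch.Tensor) -> None:
    # Assumed behavior of the default loader: shapes must match exactly,
    # then the checkpoint tensor is copied into the parameter in place.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)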
......@@ -30,17 +30,19 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
......@@ -48,19 +50,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
......@@ -86,7 +75,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -103,41 +96,29 @@ class FalconAttention(nn.Module):
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
elif self.multi_query:
self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else:
self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
skip_bias_add=True,
)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
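To make the partition/replicate rule above concrete: KV heads are split across GPUs when there are at least as many heads as tensor parallel ranks, and replicated otherwise so every GPU keeps at least one. A small arithmetic sketch with hypothetical head counts:

# Per-GPU KV heads under the rule used above: partition when
# total_num_kv_heads >= tp_size, replicate otherwise.
def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

assert num_kv_heads_per_gpu(8, 4) == 2   # 8 KV heads split across 4 GPUs
assert num_kv_heads_per_gpu(1, 4) == 1   # MQA: the single head is replicated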
......@@ -149,7 +130,6 @@ class FalconAttention(nn.Module):
self.hidden_size,
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
......@@ -196,18 +176,10 @@ class FalconAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query:
q, bias = self.query(hidden_states)
if bias is not None:
q += bias
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
......@@ -221,15 +193,19 @@ class FalconAttention(nn.Module):
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
gather_output=False,
skip_bias_add=True)
skip_bias_add=True,
linear_method=linear_method)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
......@@ -237,9 +213,9 @@ class FalconMLP(nn.Module):
4 * hidden_size,
hidden_size,
bias=config.bias,
input_is_parallel=True,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
......@@ -253,12 +229,16 @@ class FalconMLP(nn.Module):
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config)
self.mlp = FalconMLP(config)
self.self_attention = FalconAttention(config, linear_method)
self.mlp = FalconMLP(config, linear_method)
self.config = config
if config.new_decoder_architecture:
......@@ -334,7 +314,11 @@ class FalconDecoderLayer(nn.Module):
class FalconModel(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
......@@ -349,7 +333,8 @@ class FalconModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
FalconDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
......@@ -383,15 +368,18 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
self.linear_method = linear_method
self.transformer = FalconModel(config, linear_method)
self.lm_head = ParallelLMHead(
config.vocab_size,
bias=False,
gather_output=False,
config.hidden_size,
)
self.sampler = Sampler(config.vocab_size)
......@@ -415,89 +403,44 @@ class FalconForCausalLM(nn.Module):
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else:
total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
param = params_dict[name]
if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight_size = loaded_weight.size()
output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2,
head_size, *loaded_weight_size[1:])
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
wk = loaded_weight[:, [-2]].reshape(-1,
*loaded_weight_size[1:])
wv = loaded_weight[:, [-1]].reshape(-1,
*loaded_weight_size[1:])
wq = wq[head_size * head_start:head_size * head_end]
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
if separated_q_kv:
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value",
"key_value")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
loaded_weight_shape[:output_dim] +
(total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
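A sketch of the de-interleaving performed above, with hypothetical small sizes and output_dim assumed to be 0: Falcon checkpoints store each KV-head group as [q_0 .. q_{n-1}, k, v] along the output dimension, while the fused QKV layer expects contiguous [Q | K | V] blocks.

import torch

# De-interleave a Falcon-style fused QKV weight into contiguous Q, K, V blocks.
total_num_kv_heads, num_q_per_kv, head_dim, hidden = 2, 3, 4, 8
rows = total_num_kv_heads * (num_q_per_kv + 2) * head_dim
fused = torch.randn(rows, hidden)

grouped = fused.view(total_num_kv_heads, num_q_per_kv + 2, head_dim, hidden)
wq = grouped[:, :num_q_per_kv].reshape(-1, hidden)
wk = grouped[:, num_q_per_kv:num_q_per_kv + 1].reshape(-1, hidden)
wv = grouped[:, num_q_per_kv + 1:].reshape(-1, hidden)
reordered = torch.cat([wq, wk, wv], dim=0)
assert reordered.shape == fused.shape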
......@@ -30,15 +30,17 @@ from transformers import GPT2Config
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -46,7 +48,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPT2Attention(nn.Module):
def __init__(self, config: GPT2Config):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
......@@ -57,17 +63,18 @@ class GPT2Attention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(
self.c_attn = QKVParallelLinear(
self.hidden_size,
3 * self.hidden_size,
self.head_dim,
total_num_heads,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
......@@ -95,6 +102,7 @@ class GPT2MLP(nn.Module):
self,
intermediate_size: int,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -102,13 +110,13 @@ class GPT2MLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.act = get_act_fn(config.activation_function)
......@@ -121,16 +129,20 @@ class GPT2MLP(nn.Module):
class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
self.attn = GPT2Attention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
self.mlp = GPT2MLP(inner_dim, config, linear_method)
def forward(
self,
......@@ -160,24 +172,23 @@ class GPT2Block(nn.Module):
class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPT2Block(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([
GPT2Block(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -207,12 +218,15 @@ class GPT2Model(nn.Module):
class GPT2LMHeadModel(nn.Module):
def __init__(self, config: GPT2Config):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = GPT2Model(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.linear_method = linear_method
self.transformer = GPT2Model(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
......@@ -230,19 +244,12 @@ class GPT2LMHeadModel(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
......@@ -253,53 +260,19 @@ class GPT2LMHeadModel(nn.Module):
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
param = state_dict[name]
if name == "transformer.wte.weight":
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
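On the Conv1D transpose kept above: Hugging Face's GPT-2 uses Conv1D modules whose weights are stored as [in_features, out_features], whereas Linear-style layers expect [out_features, in_features], so the weight matrices are transposed at load time. A minimal sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

# Conv1D checkpoint layout is [in_features, out_features]; transposing it
# gives the [out_features, in_features] layout a Linear-style layer expects.
in_features, out_features = 8, 32
conv1d_weight = torch.randn(in_features, out_features)   # checkpoint layout
x = torch.randn(4, in_features)
assert torch.allclose(x @ conv1d_weight,                  # Conv1D forward
                      F.linear(x, conv1d_weight.t()))     # after transpose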
......@@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
......@@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module):
self.multi_query = config.multi_query
if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
)
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
......@@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1)
dim=-1,
)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
......@@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module):
self,
intermediate_size: int,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -131,13 +129,13 @@ class GPTBigMLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.act = get_act_fn(config.activation_function)
......@@ -150,16 +148,20 @@ class GPTBigMLP(nn.Module):
class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
self.attn = GPTBigCodeAttention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config)
self.mlp = GPTBigMLP(inner_dim, config, linear_method)
def forward(
self,
......@@ -189,23 +191,23 @@ class GPTBigCodeBlock(nn.Module):
class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([
GPTBigCodeBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -235,12 +237,15 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = GPTBigCodeModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.linear_method = linear_method
self.transformer = GPTBigCodeModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
......@@ -258,89 +263,21 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
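For the multi-query case handled above: with a single KV head, each GPU's fused c_attn output is its query shard plus one replicated key block and one replicated value block, so the forward split sizes are hidden_size // tp_size, head_dim, head_dim. A small sketch with hypothetical sizes:

import torch

# Split of a fused QKV activation under multi-query attention (one KV head)
# with tensor parallelism: queries are sharded per GPU, K/V are replicated.
hidden_size, num_heads, tp_size = 32, 8, 2
head_dim = hidden_size // num_heads
q_size = hidden_size // tp_size      # this GPU's query slice
kv_size = head_dim                   # replicated single KV head

qkv = torch.randn(4, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size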
......@@ -29,14 +29,17 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -44,23 +47,28 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
3 * config.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
gather_output=False,
linear_method=linear_method,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method,
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -102,18 +110,23 @@ class GPTJAttention(nn.Module):
class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
gather_output=False,
linear_method=linear_method,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
input_is_parallel=True,
linear_method=linear_method,
)
self.act = get_act_fn(config.activation_function)
......@@ -126,15 +139,19 @@ class GPTJMLP(nn.Module):
class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
if config.n_inner is None:
inner_dim = 4 * config.n_embd
else:
inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(inner_dim, config)
self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config, linear_method)
def forward(
self,
......@@ -160,7 +177,11 @@ class GPTJBlock(nn.Module):
class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
......@@ -169,7 +190,7 @@ class GPTJModel(nn.Module):
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
[GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -200,15 +221,20 @@ class GPTJModel(nn.Module):
class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(
config.n_embd,
self.transformer = GPTJModel(config, linear_method)
self.lm_head = ParallelLMHead(
config.vocab_size,
gather_output=False,
config.n_embd,
bias=True,
)
self.sampler = Sampler(config.vocab_size)
......@@ -226,43 +252,33 @@ class GPTJForCausalLM(nn.Module):
input_metadata, self.lm_head.bias)
return next_tokens
_column_parallel_weights = [
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
"lm_head.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
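The stacked_params_mapping loop above relies on each fused parameter's weight_loader knowing where a q_proj/k_proj/v_proj (or gate_proj/up_proj) shard belongs. A simplified sketch of such a loader, ignoring tensor parallel sharding and assuming equal-sized Q, K, and V blocks (the real QKVParallelLinear loader also handles sharding and unequal KV sizes):

import torch

def stacked_qkv_loader_sketch(param: torch.nn.Parameter,
                              loaded_weight: torch.Tensor,
                              shard_id: str) -> None:
    # Simplified: equal Q/K/V block sizes, no tensor parallel sharding.
    shard_size = param.shape[0] // 3
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * shard_size
    param.data[offset:offset + shard_size].copy_(loaded_weight)

qkv_proj = torch.nn.Parameter(torch.zeros(3 * 16, 16), requires_grad=False)
stacked_qkv_loader_sketch(qkv_proj, torch.randn(16, 16), "k")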