Unverified commit 7076fa1c, authored by Zhuohan Li, committed by GitHub


TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor-parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, with [GPTQ coming soon](https://github.com/vllm-project/vllm/pull/1580) (see the usage sketch below).
- The model-loading code is much simpler.
- Tensor parallelism is now supported for all MQA/GQA models, including the case where the number of key/value heads is smaller than the tensor-parallel size.
parent 660a7fcf
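The summary above is easiest to see from the user-facing API. A minimal usage sketch, not part of this commit; the checkpoint name and sampling settings are illustrative assumptions, and the model must actually be an AWQ export:

from vllm import LLM, SamplingParams

# Any supported architecture can now be loaded from a quantized checkpoint
# and sharded across GPUs; the model name below is only an example.
llm = LLM(model="TheBloke/Llama-2-7B-AWQ",
          quantization="awq",
          tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)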
@@ -29,14 +29,17 @@ from transformers import GPTNeoXConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearMethodBase,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
-                                                       ColumnParallelLinear,
-                                                       RowParallelLinear)
+    get_tensor_model_parallel_world_size)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -44,7 +47,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 class GPTNeoXAttention(nn.Module):

-    def __init__(self, config: GPTNeoXConfig):
+    def __init__(
+        self,
+        config: GPTNeoXConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -56,15 +63,16 @@ class GPTNeoXAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)

-        self.query_key_value = ColumnParallelLinear(
+        self.query_key_value = QKVParallelLinear(
             config.hidden_size,
-            3 * config.hidden_size,
-            gather_output=False,
+            self.head_size,
+            self.total_num_heads,
+            linear_method=linear_method,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )
         scaling = self.head_size**-0.5
@@ -100,17 +108,21 @@ class GPTNeoXAttention(nn.Module):
 class GPTNeoXMLP(nn.Module):

-    def __init__(self, config: GPTNeoXConfig):
+    def __init__(
+        self,
+        config: GPTNeoXConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
-            gather_output=False,
+            linear_method=linear_method,
         )
         self.dense_4h_to_h = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )
         self.act = get_act_fn(config.hidden_act)
@@ -123,15 +135,19 @@ class GPTNeoXMLP(nn.Module):
 class GPTNeoXLayer(nn.Module):

-    def __init__(self, config: GPTNeoXConfig):
+    def __init__(
+        self,
+        config: GPTNeoXConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                      eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config)
-        self.mlp = GPTNeoXMLP(config)
+        self.attention = GPTNeoXAttention(config, linear_method)
+        self.mlp = GPTNeoXMLP(config, linear_method)

     def forward(
         self,
@@ -169,7 +185,11 @@ class GPTNeoXLayer(nn.Module):
 class GPTNeoXModel(nn.Module):

-    def __init__(self, config: GPTNeoXConfig):
+    def __init__(
+        self,
+        config: GPTNeoXConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
@@ -177,8 +197,10 @@ class GPTNeoXModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([
+            GPTNeoXLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                              eps=config.layer_norm_eps)
@@ -210,15 +232,18 @@ class GPTNeoXModel(nn.Module):
 class GPTNeoXForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
-        self.gpt_neox = GPTNeoXModel(config)
-        self.embed_out = ColumnParallelLinear(
-            config.hidden_size,
+        self.linear_method = linear_method
+        self.gpt_neox = GPTNeoXModel(config, linear_method)
+        self.embed_out = ParallelLMHead(
             config.vocab_size,
-            bias=False,
-            gather_output=False,
+            config.hidden_size,
         )
         self.sampler = Sampler(config.vocab_size)
@@ -236,50 +261,35 @@ class GPTNeoXForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = [
-        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
-        "dense_h_to_4h.bias"
-    ]
-    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
-
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
-
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
-            param = state_dict[name]
+            param = params_dict[name]
             if "query_key_value" in name:
-                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
-                # [num_heads * 3 * head_size, hidden_size], while the
-                # required shape is [3 * num_heads * head_size, hidden_size].
+                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
                 # Thus, we need weight conversion.
-                shard_size = param.shape[0]
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                output_dim = getattr(param, "output_dim", None)
                 num_heads = self.config.num_attention_heads
-                hidden_size = self.config.hidden_size
-                head_size = hidden_size // num_heads
-                if "query_key_value.weight" in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size,
-                                                       hidden_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif "query_key_value.bias" in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected weight name: {name}")
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
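To make the converted layout concrete, here is a small standalone sketch (not part of the commit; toy sizes) of the view/transpose/reshape that the new GPT-NeoX loader applies when the parameter advertises an output_dim attribute:

import torch

num_heads, head_size, hidden_size = 2, 3, 6
# GPT-NeoX checkpoints store the fused QKV weight interleaved per head as
# [num_heads * 3 * head_size, hidden_size]; QKVParallelLinear expects the
# q/k/v blocks stacked as [3 * num_heads * head_size, hidden_size].
w = torch.arange(num_heads * 3 * head_size * hidden_size,
                 dtype=torch.float32).view(-1, hidden_size)
output_dim = 0  # the sharded output dimension of this parameter
shape = w.shape
w = w.view(shape[:output_dim] + (num_heads, 3, -1) + shape[output_dim + 1:])
w = w.transpose(output_dim, output_dim + 1)  # swap the head and q/k/v axes
w = w.reshape(shape)  # back to 2-D, now ordered q-block, k-block, v-block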
@@ -9,15 +9,17 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
-                                                       RowParallelLinear,
-                                                       VocabParallelEmbedding)
-from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    get_tensor_model_parallel_world_size)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -30,20 +32,17 @@ class InternLMMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(
-            hidden_size,
-            2 * intermediate_size,
-            bias=False,
-            gather_output=False,
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            input_is_parallel=True,
-        )
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -65,6 +64,7 @@ class InternLMAttention(nn.Module):
         bias: bool,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -79,17 +79,18 @@ class InternLMAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

-        self.qkv_proj = ColumnParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size,
-            3 * self.total_num_heads * self.head_dim,
+            self.head_dim,
+            self.total_num_heads,
             bias=bias,
-            gather_output=False,
+            linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=bias,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )
         self.attn = PagedAttentionWithRoPE(
             self.num_heads,
@@ -118,7 +119,11 @@ class InternLMAttention(nn.Module):
 class InternLMDecoderLayer(nn.Module):

-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -130,11 +135,13 @@ class InternLMDecoderLayer(nn.Module):
             bias=config.bias,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            linear_method=linear_method,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -171,7 +178,11 @@ class InternLMDecoderLayer(nn.Module):
 class InternLMModel(nn.Module):

-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -183,7 +194,7 @@ class InternLMModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config)
+            InternLMDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -216,17 +227,16 @@ class InternLMModel(nn.Module):
 class InternLMForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
-        self.model = InternLMModel(config)
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(
-            config.hidden_size,
-            vocab_size,
-            bias=False,
-            gather_output=False,
-        )
+        self.linear_method = linear_method
+        self.model = InternLMModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)

     def forward(
@@ -243,69 +253,33 @@ class InternLMForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = [
-        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
-    ]
-    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
-
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
-
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
-                continue
-
-            is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(
-                    ["q_proj", "k_proj", "v_proj"]):
-                if att_weight_name not in name:
-                    continue
-                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[0] // 3
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_attention_weight = True
-                break
-            if is_attention_weight:
-                continue
-
-            is_gate_up_weight = False
-            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_gate_up_weight = True
-                break
-            if is_gate_up_weight:
-                continue
-
-            param = state_dict[name]
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
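The stacked_params_mapping above routes each q_proj/k_proj/v_proj and gate_proj/up_proj checkpoint tensor into its slice of a fused parameter via that parameter's weight_loader. A rough toy illustration of the idea (my own simplification, not vLLM's MergedColumnParallelLinear.weight_loader):

import torch

# Toy fused gate_up parameter: two stacked [4, 8] projections.
gate_up = torch.empty(8, 8)

def toy_stacked_weight_loader(param, loaded, shard_id):
    # shard_id 0 -> gate_proj rows, shard_id 1 -> up_proj rows.
    shard = param.shape[0] // 2
    param[shard * shard_id:shard * (shard_id + 1)].copy_(loaded)

toy_stacked_weight_loader(gate_up, torch.ones(4, 8), shard_id=0)   # gate_proj
toy_stacked_weight_loader(gate_up, torch.zeros(4, 8), shard_id=1)  # up_proj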
@@ -33,17 +33,19 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.quantized_linear import ParallelLinear
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
-from vllm.model_executor.quantization_utils import QuantizationConfig
-from vllm.model_executor.weight_utils import (
-    convert_pyslice_to_tensor, hf_model_weights_iterator,
-    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
+    get_tensor_model_parallel_world_size)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,19 +58,17 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_up_proj = ParallelLinear.column(hidden_size,
-                                                  2 * intermediate_size,
-                                                  bias=False,
-                                                  gather_output=False,
-                                                  quant_config=quant_config)
-        self.down_proj = ParallelLinear.row(intermediate_size,
-                                            hidden_size,
-                                            bias=False,
-                                            input_is_parallel=True,
-                                            quant_config=quant_config)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -91,7 +91,7 @@ class LlamaAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -109,7 +109,6 @@ class LlamaAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -117,21 +116,19 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

-        self.qkv_proj = ParallelLinear.column(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size,
-            (self.total_num_heads +
-             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
             self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
             bias=False,
-            gather_output=False,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
-        self.o_proj = ParallelLinear.row(
+        self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            input_is_parallel=True,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
         self.attn = PagedAttentionWithRoPE(
             self.num_heads,
@@ -165,11 +162,10 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
@@ -181,13 +177,13 @@ class LlamaDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -227,20 +223,18 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            vocab_size,
+            config.vocab_size,
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, quant_config)
+            LlamaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -276,19 +270,13 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.quant_config = quant_config
-        self.model = LlamaModel(config, quant_config)
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        # NOTE: The LM head is not quantized.
-        self.lm_head = ParallelLinear.column(config.hidden_size,
-                                             vocab_size,
-                                             bias=False,
-                                             gather_output=False,
-                                             quant_config=None)
+        self.linear_method = linear_method
+        self.model = LlamaModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)

     def forward(
@@ -305,124 +293,33 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_layers = []
-    _row_parallel_layers = ["o_proj", "down_proj"]
-
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        if self.quant_config is None:
-            col_weight_suffixes = ["weight"]
-            row_weight_suffixes = ["weight"]
-        else:
-            col_weight_suffixes = (
-                self.quant_config.get_col_parallel_tensor_names())
-            row_weight_suffixes = (
-                self.quant_config.get_row_parallel_tensor_names())
-
-        column_parallel_weights: List[str] = []
-        for layer in self._column_parallel_layers:
-            for suffix in col_weight_suffixes:
-                column_parallel_weights.append(f"{layer}.{suffix}")
-        row_parallel_weights: List[str] = []
-        for layer in self._row_parallel_layers:
-            for suffix in row_weight_suffixes:
-                row_parallel_weights.append(f"{layer}.{suffix}")
-
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        q_proj_shard_size = (self.config.hidden_size // tp_size)
-        num_kv_heads_replicas = max(1,
-                                    tp_size // self.config.num_key_value_heads)
-        num_kv_heads_per_gpu = max(1,
-                                   self.config.num_key_value_heads // tp_size)
-        kv_proj_shard_size = (self.config.hidden_size //
-                              self.config.num_attention_heads *
-                              num_kv_heads_per_gpu)
-        attention_weight_specs = [
-            # (weight_name, shard_size, offset)
-            ("q_proj", q_proj_shard_size, 0),
-            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-            ("v_proj", kv_proj_shard_size,
-             q_proj_shard_size + kv_proj_shard_size),
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
-        state_dict = self.state_dict()
-
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-
-            packed_dim = None
-            is_transposed = False
-            if self.quant_config is not None:
-                packed_dim = self.quant_config.get_packed_dim(name)
-                is_transposed = self.quant_config.is_transposed(name)
-            if is_transposed:
-                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-                loaded_weight = loaded_weight.T
-
-            is_attention_weight = False
-            for weight_name, shard_size, offset in attention_weight_specs:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = state_dict[name.replace(weight_name, "qkv_proj")]
-                if is_transposed:
-                    param = param.T
-
-                if packed_dim is not None:
-                    shard_dim = 0 if not is_transposed else 1
-                    if packed_dim == shard_dim:
-                        shard_size //= self.quant_config.pack_factor
-                        offset //= self.quant_config.pack_factor
-
-                if weight_name in ["k_proj", "v_proj"]:
-                    shard_id = tp_rank // num_kv_heads_replicas
-                else:
-                    shard_id = tp_rank
-                loaded_weight = loaded_weight[shard_size *
-                                              shard_id:shard_size *
-                                              (shard_id + 1)]
-                param_slice = param.data[offset:offset + shard_size]
-                assert param_slice.shape == loaded_weight.shape
-
-                param_slice.copy_(loaded_weight)
-                is_attention_weight = True
-                break
-            if is_attention_weight:
-                continue
-
-            is_gate_up_weight = False
-            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                if weight_name not in name:
-                    continue
-                param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                if is_transposed:
-                    param = param.T
-                shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
-                                              (tp_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_gate_up_weight = True
-                break
-            if is_gate_up_weight:
-                continue
-
-            param = state_dict[name]
-            if is_transposed:
-                param = param.T
-
-            if "embed_tokens" in name or "lm_head" in name:
-                load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tp_rank)
-                continue
-
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         column_parallel_weights,
-                                         row_parallel_weights, tp_rank)
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
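For reference, the default_weight_loader fallback used above is, as far as this refactor is concerned, just a shape-checked copy; a minimal sketch under that assumption (not a verbatim copy of vLLM's helper):

import torch

def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    # Fallback for parameters without a custom weight_loader attribute:
    # the checkpoint tensor must already match the parameter shape.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)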
@@ -33,17 +33,19 @@ from transformers import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.quantized_linear import ParallelLinear
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
-from vllm.model_executor.quantization_utils import QuantizationConfig
-from vllm.model_executor.weight_utils import (
-    convert_pyslice_to_tensor, hf_model_weights_iterator,
-    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
+    get_tensor_model_parallel_world_size)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,19 +58,17 @@ class MistralMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_up_proj = ParallelLinear.column(hidden_size,
-                                                  2 * intermediate_size,
-                                                  bias=False,
-                                                  gather_output=False,
-                                                  quant_config=quant_config)
-        self.down_proj = ParallelLinear.row(intermediate_size,
-                                            hidden_size,
-                                            bias=False,
-                                            input_is_parallel=True,
-                                            quant_config=quant_config)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -89,7 +89,7 @@ class MistralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 quant_config: Optional[QuantizationConfig] = None,
+                 linear_method: Optional[LinearMethodBase] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -98,8 +98,15 @@ class MistralAttention(nn.Module):
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -107,20 +114,19 @@ class MistralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window

-        self.qkv_proj = ParallelLinear.column(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) *
             self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
             bias=False,
-            gather_output=False,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
-        self.o_proj = ParallelLinear.row(
+        self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            input_is_parallel=True,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@@ -153,7 +159,7 @@ class MistralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MistralConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -165,13 +171,13 @@ class MistralDecoderLayer(nn.Module):
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            quant_config=quant_config,
+            linear_method=linear_method,
             sliding_window=config.sliding_window)
         self.mlp = MistralMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            quant_config=quant_config,
+            linear_method=linear_method,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -211,20 +217,19 @@ class MistralModel(nn.Module):
     def __init__(
         self,
        config: MistralConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            vocab_size,
+            config.vocab_size,
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            MistralDecoderLayer(config, quant_config)
+            MistralDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -260,19 +265,13 @@ class MistralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MistralConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.quant_config = quant_config
-        self.model = MistralModel(config, quant_config)
-        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        # NOTE: The LM head is not quantized.
-        self.lm_head = ParallelLinear.column(config.hidden_size,
-                                             vocab_size,
-                                             bias=False,
-                                             gather_output=False,
-                                             quant_config=None)
+        self.linear_method = linear_method
+        self.model = MistralModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.sampler = Sampler(config.vocab_size)

     def forward(
@@ -289,118 +288,33 @@ class MistralForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_layers = []
-    _row_parallel_layers = ["o_proj", "down_proj"]
-
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        if self.quant_config is None:
-            col_weight_suffixes = ["weight"]
-            row_weight_suffixes = ["weight"]
-        else:
-            col_weight_suffixes = (
-                self.quant_config.get_col_parallel_tensor_names())
-            row_weight_suffixes = (
-                self.quant_config.get_row_parallel_tensor_names())
-
-        column_parallel_weights: List[str] = []
-        for layer in self._column_parallel_layers:
-            for suffix in col_weight_suffixes:
-                column_parallel_weights.append(f"{layer}.{suffix}")
-        row_parallel_weights: List[str] = []
-        for layer in self._row_parallel_layers:
-            for suffix in row_weight_suffixes:
-                row_parallel_weights.append(f"{layer}.{suffix}")
-
-        tp_size = get_tensor_model_parallel_world_size()
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        q_proj_shard_size = (self.config.hidden_size // tp_size)
-        kv_proj_shard_size = (self.config.hidden_size //
-                              self.config.num_attention_heads *
-                              self.config.num_key_value_heads // tp_size)
-        attention_weight_specs = [
-            # (weight_name, shard_size, offset)
-            ("q_proj", q_proj_shard_size, 0),
-            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
-            ("v_proj", kv_proj_shard_size,
-             q_proj_shard_size + kv_proj_shard_size),
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
-        state_dict = self.state_dict()
-
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-
-            packed_dim = None
-            is_transposed = False
-            if self.quant_config is not None:
-                packed_dim = self.quant_config.get_packed_dim(name)
-                is_transposed = self.quant_config.is_transposed(name)
-            if is_transposed:
-                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-                loaded_weight = loaded_weight.T
-
-            is_attention_weight = False
-            for weight_name, shard_size, offset in attention_weight_specs:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = state_dict[name.replace(weight_name, "qkv_proj")]
-                if is_transposed:
-                    param = param.T
-
-                if packed_dim is not None:
-                    shard_dim = 0 if not is_transposed else 1
-                    if packed_dim == shard_dim:
-                        shard_size //= self.quant_config.pack_factor
-                        offset //= self.quant_config.pack_factor
-
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[offset:offset + shard_size]
-                assert param_slice.shape == loaded_weight.shape
-
-                param_slice.copy_(loaded_weight)
-                is_attention_weight = True
-                break
-            if is_attention_weight:
-                continue
-
-            is_gate_up_weight = False
-            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                if weight_name not in name:
-                    continue
-                param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                if is_transposed:
-                    param = param.T
-                shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
-                assert param_slice.shape == loaded_weight.shape
-                param_slice.copy_(loaded_weight)
-                is_gate_up_weight = True
-                break
-            if is_gate_up_weight:
-                continue
-
-            param = state_dict[name]
-            if is_transposed:
-                param = param.T
-
-            if "embed_tokens" in name or "lm_head" in name:
-                load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
-                continue
-
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         column_parallel_weights,
-                                         row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
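The partition-or-replicate branch added above determines how many KV heads each tensor-parallel rank ends up with; a quick numeric check (toy head counts, unrelated to any specific checkpoint):

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Mirrors the logic above: partition when there are enough KV heads,
    # otherwise every rank keeps one replicated KV head.
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

print(kv_heads_per_rank(8, 2))   # 4 KV heads per GPU (partitioned)
print(kv_heads_per_rank(8, 16))  # 1 KV head per GPU (replicated)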
@@ -10,15 +10,17 @@ from transformers import MptConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearMethodBase,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
-                                              hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
-                                                       ColumnParallelLinear,
-                                                       RowParallelLinear)
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -39,7 +41,11 @@ def _get_alibi_slopes(
 class MptAttention(nn.Module):

-    def __init__(self, config: MptConfig):
+    def __init__(
+        self,
+        config: MptConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
@@ -49,11 +55,13 @@ class MptAttention(nn.Module):
         assert not config.attn_config.prefix_lm
         assert config.attn_config.alibi

-        self.qkv_proj = ColumnParallelLinear(
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
             self.d_model,
-            3 * self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
             bias=not config.no_bias,
-            gather_output=False,
+            linear_method=linear_method,
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -62,7 +70,7 @@ class MptAttention(nn.Module):
             self.d_model,
             self.d_model,
             bias=not config.no_bias,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -91,7 +99,7 @@ class MptAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         del position_ids  # unused.
-        qkv, _ = self.qkv_proj(hidden_states)
+        qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
@@ -107,7 +115,11 @@ class MptAttention(nn.Module):
 class MptMLP(nn.Module):

-    def __init__(self, config: MptConfig):
+    def __init__(
+        self,
+        config: MptConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
@@ -116,14 +128,14 @@ class MptMLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=not config.no_bias,
-            gather_output=False,
+            linear_method=linear_method,
         )
         self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=not config.no_bias,
-            input_is_parallel=True,
+            linear_method=linear_method,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,13 +147,17 @@ class MptMLP(nn.Module):
 class MptBlock(nn.Module):

-    def __init__(self, config: MptConfig):
+    def __init__(
+        self,
+        config: MptConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config)
+        self.attn = MptAttention(config, linear_method)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config)
+        self.ffn = MptMLP(config, linear_method)

     def forward(
         self,
@@ -168,7 +184,11 @@ class MptBlock(nn.Module):
 class MptModel(nn.Module):

-    def __init__(self, config: MptConfig):
+    def __init__(
+        self,
+        config: MptConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +198,7 @@ class MptModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MptBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -215,14 +235,17 @@ class MptModel(nn.Module):
 class MptForCausalLM(nn.Module):

-    def __init__(self, config: MptConfig):
+    def __init__(
+        self,
+        config: MptConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
+        self.linear_method = linear_method

-        self.transformer = MptModel(config)
-        # TODO(zhuohan): create a new weight after implementing pipeline
-        # parallelism
+        self.transformer = MptModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
         self.sampler = Sampler(config.vocab_size)
@@ -240,45 +263,15 @@ class MptForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
-    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
-
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        tp_world_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        state_dict = self.state_dict()
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
-            if "Wqkv" in name:
-                # NOTE(woosuk): MPT's fused QKV has the shape of
-                # [3 * num_heads * head_size, hidden_size].
-                # When tensor model parallelism is used, we need to shard
-                # the weight along the hidden dimension.
-                total_num_heads = self.config.num_attention_heads
-                hidden_size = self.config.hidden_size
-                head_size = hidden_size // total_num_heads
-                num_heads = total_num_heads // tp_world_size
-                head_start = tp_rank * num_heads
-                head_end = (tp_rank + 1) * num_heads
-                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
-                if name.endswith(".weight"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
-                name = name.replace("Wqkv", "qkv_proj")
-            param = state_dict[name]
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights, tp_rank)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
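MPT ties lm_head_weight to the token embedding, which is why params_dict above is built with remove_duplicate=False so both parameter names stay addressable during loading. A small standalone illustration of the difference (toy module, not vLLM code):

import torch.nn as nn

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.wte.weight  # tied weights

m = TiedToy()
print(len(dict(m.named_parameters())))                        # 1: duplicates removed
print(len(dict(m.named_parameters(remove_duplicate=False))))  # 2: both names kept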
[Several additional file diffs are collapsed in this view.]
@@ -2,7 +2,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import List, Sequence
+from typing import Sequence

 import torch
@@ -24,7 +24,7 @@ def split_tensor_along_last_dim(
     tensor: torch.Tensor,
     num_partitions: int,
     contiguous_split_chunks: bool = False,
-) -> List[torch.Tensor]:
+) -> Sequence[torch.Tensor]:
     """ Split a tensor along its last dimension.

     Arguments:
@@ -46,25 +46,3 @@ def split_tensor_along_last_dim(
         return tuple(chunk.contiguous() for chunk in tensor_list)
     return tensor_list
-
-
-class VocabUtility:
-    """ Split the vocabulary into `world_size` chunks and return the first
-        and last index of the vocabulary belonging to the `rank`
-        partition: Note that indices in [fist, last)
-    """
-
-    @staticmethod
-    def vocab_range_from_per_partition_vocab_size(
-            per_partition_vocab_size: int, rank: int) -> Sequence[int]:
-        index_f = rank * per_partition_vocab_size
-        index_l = index_f + per_partition_vocab_size
-        return index_f, index_l
-
-    @staticmethod
-    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
-                                           world_size: int) -> Sequence[int]:
-        per_partition_vocab_size = divide(global_vocab_size, world_size)
-        return VocabUtility.vocab_range_from_per_partition_vocab_size(
-            per_partition_vocab_size, rank)
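For completeness, a usage example of the utility whose return annotation changed (assuming the module is importable as vllm.model_executor.parallel_utils.utils; illustrative shapes only):

import torch
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim

x = torch.randn(2, 8)
chunks = split_tensor_along_last_dim(x, num_partitions=4)
print([c.shape for c in chunks])  # four views of shape (2, 2)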
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensors(cls) -> Dict[str, int]:
return {"qweight": 1, "qzeros": 1}
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_col_parallel_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_row_parallel_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
"""Utils for model executor.""" """Utils for model executor."""
import random import random
from typing import Any, Dict, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -11,3 +12,24 @@ def set_random_seed(seed: int) -> None: ...@@ -11,3 +12,24 @@ def set_random_seed(seed: int) -> None:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: Optional[Dict[str, Any]],
):
"""Set attributes on a weight tensor.
This method is used to set attributes on a weight tensor. This method
will not overwrite existing attributes.
Args:
weight: The weight tensor.
weight_attrs: A dictionary of attributes to set on the weight tensor.
"""
if weight_attrs is None:
return
for key, value in weight_attrs.items():
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
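set_weight_attrs is what the refactored linear layers use to attach loader metadata (for example output_dim and weight_loader) to their parameters, which load_weights later reads back with getattr. A small illustration, assuming the module is importable as vllm.model_executor.utils:

import torch
from torch.nn import Parameter
from vllm.model_executor.utils import set_weight_attrs

weight = Parameter(torch.empty(16, 8))

def my_weight_loader(param, loaded_weight):
    # Hypothetical loader used only for this illustration.
    param.data.copy_(loaded_weight)

# Attach metadata that a load_weights() loop can later read back.
set_weight_attrs(weight, {"output_dim": 0, "weight_loader": my_weight_loader})
assert getattr(weight, "output_dim") == 0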