Commit e00b0a19 authored by zhuwenwen

merge v0.3.3

parents ead94d93 3f1166ab
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import GemmaConfig

from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -16,7 +32,7 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -27,15 +43,14 @@

KVCache = Tuple[torch.Tensor, torch.Tensor]
class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
@@ -45,10 +60,7 @@
            hidden_size,
            bias=False,
            linear_method=linear_method)
        self.act_fn = GeluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
@@ -57,52 +69,64 @@
        return x
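The only functional change from the usual SwiGLU block here is the activation: Gemma gates with GELU via GeluAndMul instead of SiLU. As a rough reference for what that op computes (a standalone sketch, not vLLM's kernel; the half-split layout of the merged gate/up projection is the assumption):

import torch
import torch.nn.functional as F

def gelu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: (..., 2 * intermediate_size); first half gates, second half is the up projection.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.gelu(gate) * up

x = torch.randn(2, 8)                   # toy input with intermediate_size = 4
print(gelu_and_mul_reference(x).shape)  # torch.Size([2, 4])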
class GemmaAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
@@ -112,7 +136,7 @@
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -120,31 +144,27 @@
        return output
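The branch above is the grouped-query attention bookkeeping shared by the models in this merge: KV heads are partitioned across tensor-parallel ranks when there are at least as many KV heads as ranks, and replicated otherwise. A minimal sketch of that arithmetic (hypothetical helper, not a vLLM API):

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate, several ranks share one head.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

print(kv_heads_per_rank(8, 4))   # 2 KV heads per rank
print(kv_heads_per_rank(1, 4))   # 1: the single KV head is replicated on all 4 ranks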
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
@@ -181,25 +201,22 @@
        return hidden_states, residual
class GemmaModel(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -212,6 +229,9 @@
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size)
        hidden_states *= self.config.hidden_size**0.5
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
@@ -226,20 +246,44 @@
        return hidden_states
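Unlike the other models in this merge, GemmaModel scales the token embeddings by sqrt(hidden_size) before the first decoder layer (the same embedding matrix is reused as the output head in GemmaForCausalLM below). A toy illustration of the effect of that scaling (shapes and magnitudes invented):

import torch

hidden_size = 2048
# Embedding rows are typically small in magnitude; scaling by sqrt(hidden_size)
# brings the hidden states back to roughly unit variance.
embeddings = torch.randn(4, hidden_size) * hidden_size**-0.5
hidden_states = embeddings * hidden_size**0.5   # same normalization as GemmaModel.forward
print(float(hidden_states.std()))               # ~1.0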
class GemmaForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.sampler = Sampler(config.vocab_size)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -256,8 +300,8 @@
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.embed_tokens.weight,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
@@ -274,26 +318,29 @@
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                if "norm.weight" in name:
                    loaded_weight += 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -43,12 +23,11 @@
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
class InternLM2MLP(nn.Module):

    def __init__(
        self,
@@ -56,16 +35,16 @@
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@@ -74,31 +53,11 @@
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
class InternLM2Attention(nn.Module):

    def __init__(
        self,
@@ -106,10 +65,10 @@
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
@@ -117,8 +76,15 @@
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
@@ -126,7 +92,7 @@
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
@@ -134,17 +100,18 @@
            bias=False,
            linear_method=linear_method,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
@@ -159,47 +126,46 @@
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.wo(attn_output)
        return output
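As in the other grouped-query models in this commit, the fused wqkv output is split by size, not into three equal chunks: q_size is num_heads * head_dim while kv_size is num_kv_heads * head_dim. A quick shape check with toy dimensions:

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

qkv = torch.randn(3, q_size + 2 * kv_size)   # fused projection output for 3 tokens
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)             # (3, 32) (3, 8) (3, 8)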
class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
@@ -207,46 +173,48 @@
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
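The two-value calls to attention_norm and ffn_norm rely on vLLM's RMSNorm accepting an optional residual: the residual add and the normalization are fused, and the pre-norm sum is handed back so the next block can reuse it. A plain-PyTorch sketch of that calling convention (illustrative, not the actual fused kernel):

import torch

def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    # Add the residual first, keep the sum as the new residual,
    # and return the RMS-normalized sum alongside it.
    x = x + residual
    new_residual = x
    x_norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_norm * weight, new_residual

hidden, residual = torch.randn(4, 16), torch.randn(4, 16)
normed, residual = fused_add_rms_norm(hidden, residual, torch.ones(16))
print(normed.shape, residual.shape)   # torch.Size([4, 16]) torch.Size([4, 16])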
class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
@@ -255,32 +223,33 @@
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = InternLM2Model(config, linear_method)
        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
@@ -299,7 +268,7 @@
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.output.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

@@ -310,11 +279,8 @@
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
@@ -337,6 +303,23 @@
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    config = self.config
                    kv_groups = config.num_attention_heads // config.num_key_value_heads
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, 'q')
                    weight_loader(param, wk, 'k')
                    weight_loader(param, wv, 'v')
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
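The wqkv branch above is needed because InternLM2 checkpoints interleave query, key and value per KV group inside a single fused weight, rather than storing three contiguous blocks. A toy demonstration of the reshape-and-split with invented dimensions:

import torch

hidden_size, num_heads, num_kv_heads = 16, 4, 2
head_dim = hidden_size // num_heads           # 4
kv_groups = num_heads // num_kv_heads         # 2 query heads per KV head

# Fused layout: for each KV head, kv_groups query heads, then one K and one V head.
wqkv = torch.randn(num_kv_heads * (kv_groups + 2) * head_dim, hidden_size)

w = wqkv.view(-1, 2 + kv_groups, head_dim, hidden_size)
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, hidden_size)   # (num_heads * head_dim, hidden_size)
wk = wk.reshape(-1, hidden_size)   # (num_kv_heads * head_dim, hidden_size)
wv = wv.reshape(-1, hidden_size)
print(wq.shape, wk.shape, wv.shape)  # (16, 16) (8, 16) (8, 16)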
@@ -27,6 +27,7 @@
from torch import nn
from transformers import LlamaConfig

from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
@@ -38,7 +39,7 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -90,6 +91,8 @@
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -119,13 +122,13 @@
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )
@@ -139,7 +142,8 @@
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   sliding_window=sliding_window)

    def forward(
        self,
@@ -170,14 +174,18 @@
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        sliding_window = getattr(config, "sliding_window", None)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
            bias=getattr(config, "bias", False),
            sliding_window=sliding_window,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
@@ -225,14 +233,19 @@
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
@@ -263,18 +276,56 @@

class LlamaForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
...
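The LoRA plumbing added to LlamaModel and LlamaForCausalLM reserves lora_extra_vocab_size extra token rows per LoRA slot in the embedding and switches the LM head padding from DEFAULT_VOCAB_PADDING_SIZE to lora_config.lora_vocab_padding_size so the vocab dimension stays aligned for the LoRA kernels. A rough sketch of the kind of size arithmetic involved (all constants below are invented for illustration, not vLLM's defaults):

def padded_vocab_size(vocab_size: int, extra_per_lora: int, max_loras: int,
                      padding: int) -> int:
    # Extra rows for LoRA-added tokens, then round up to the padding multiple.
    unpadded = vocab_size + extra_per_lora * max_loras
    return (unpadded + padding - 1) // padding * padding

print(padded_vocab_size(32000, extra_per_lora=256, max_loras=2, padding=1024))  # 32768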
@@ -23,30 +23,29 @@
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import MixtralConfig

from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -54,110 +53,94 @@

KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(batch_size, sequence_length,
                                        hidden_size)
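ws stacks every expert's w1 and w3 projections (w1 in the first self.intermediate_size rows, w3 in the second), and weight_loader copies only this rank's slice of the intermediate dimension into that stack. A toy sketch of the placement with dummy CPU tensors (dimensions invented):

import torch

num_experts, hidden_size, intermediate_size, tp_size, tp_rank = 2, 8, 4, 2, 0
shard_size = intermediate_size // tp_size                   # per-rank slice of the intermediate dim

ws = torch.empty(num_experts, 2 * shard_size, hidden_size)  # [w1; w3] stacked per expert
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

w1 = torch.randn(intermediate_size, hidden_size)            # full checkpoint weights for expert 0
w3 = torch.randn(intermediate_size, hidden_size)
ws[0, 0:shard_size, :] = w1[shard, :]                       # w1 shard fills the first half
ws[0, shard_size:2 * shard_size, :] = w3[shard, :]          # w3 shard fills the second half
print(ws.shape)                                             # torch.Size([2, 4, 8])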
class MixtralAttention(nn.Module):
@@ -257,8 +240,11 @@
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -299,14 +285,19 @@
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
@@ -333,18 +324,52 @@

class MixtralForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
@@ -378,6 +403,14 @@
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
@@ -387,6 +420,7 @@
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
@@ -399,14 +433,22 @@
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
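The fused_moe call in the new MixtralMoE collapses the routing that the unfused implementation in the next file still spells out per expert: softmax over the router logits, top-k selection, renormalization of the selected weights, then a weighted sum of expert outputs. A dense, naive reference of that routing math (for intuition only, nothing like the fused kernel):

import torch
import torch.nn.functional as F

def naive_moe(x, experts, router_logits, top_k):
    # x: (tokens, hidden); experts: list of callables; router_logits: (tokens, num_experts)
    weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, selected = torch.topk(weights, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize over the chosen experts
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # Weight is zero for tokens that did not route to expert e.
        w = (weights * (selected == e)).sum(dim=-1, keepdim=True).to(x.dtype)
        out += expert(x) * w
    return out

experts = [torch.nn.Linear(8, 8) for _ in range(4)]
print(naive_moe(torch.randn(5, 8), experts, torch.randn(5, 4), top_k=2).shape)  # torch.Size([5, 8])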
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
linear_method=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
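# Editorial sketch (not part of the original file): the routing in
# MixtralMoE.forward above is plain top-k softmax gating. The helper below is
# a hypothetical, standalone illustration of that step, assuming only torch.
def _topk_softmax_routing_sketch(
        router_logits: torch.Tensor,
        top_k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Softmax over the expert dimension, keep the top-k weights per token,
    # then renormalize the kept weights so they sum to 1.
    weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, experts = torch.topk(weights, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, experts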
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
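# Editorial note (sketch, not vLLM code): with grouped-query attention the KV
# heads are either partitioned or replicated across tensor-parallel ranks.
# For example, 8 KV heads with tp_size=4 gives 8 // 4 = 2 KV heads per rank
# (partitioned), while 2 KV heads with tp_size=4 gives max(1, 2 // 4) = 1 KV
# head per rank, i.e. each KV head is served by 2 ranks (replicated).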
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
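# Editorial note (sketch): stacked_params_mapping above folds the separate HF
# checkpoint tensors into vLLM's fused parameters. For an illustrative
# (hypothetical) name, "model.layers.0.self_attn.q_proj.weight" is remapped to
# the "model.layers.0.self_attn.qkv_proj.weight" entry of params_dict and
# loaded with shard_id "q", so the parameter's shard-aware weight_loader copies
# it into the query slice of the fused QKV weight.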
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
with torch.inference_mode():
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids.flatten())
return logits
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.llama.model import LlamaForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers.models.llama import LlamaForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = LlamaForSampling.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
# This model requires the hf_olmo package, which provides OLMoConfig.
from hf_olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
@property
def output_multiplier(self) -> float:
return 0.5
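# Editorial sketch (not part of the original file): SwiGLU halves the last
# dimension because the input is chunked into a value half and a gate half,
# which is why output_multiplier is 0.5 and OlmoMLP.ff_out below takes
# int(0.5 * hidden_size) input features. For example, assuming only torch:
#     x = torch.randn(2, 8)
#     assert SwiGLU()(x).shape == (2, 4)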
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.d_model
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = self.config.n_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
# Layer norms.
self.attn_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Attention input projection. Projects x -> (q, k, v)
self.att_proj = QKVParallelLinear(
config.d_model,
self.head_dim,
self.total_num_heads,
bias=config.include_bias,
linear_method=linear_method,
)
# Rotary embeddings.
if self.config.rope:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
self.attn_out = RowParallelLinear(
config.d_model,
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states)
qkv, _ = self.att_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.attn_out(attn_output)
return output
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
is not None else config.mlp_ratio * config.d_model)
# Layer norms.
self.ff_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Feed-forward input projection.
self.ff_proj = ColumnParallelLinear(
config.d_model,
self.hidden_size,
bias=config.include_bias,
linear_method=linear_method,
)
# Activation function.
# self.act = SiluAndMul()
# self.act.output_multiplier = 0.5
self.act = SwiGLU()
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection.
self.ff_out = RowParallelLinear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self.ff_norm(x)
x, _ = self.ff_proj(x)
x = self.act(x)
x, _ = self.ff_out(x)
x = og_x + x
return x
class OlmoBlock(nn.Module):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
# Attention block.
self.attn = OlmoAttention(config, linear_method)
# MLP block.
self.mlp = OlmoMLP(config, linear_method)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Attention block.
og_x = hidden_states
x = self.attn(positions, hidden_states, kv_cache, input_metadata)
x = x + og_x
# MLP block.
hidden_states = self.mlp(x)
return hidden_states
class OlmoModel(nn.Module):
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict(
dict(
wte=VocabParallelEmbedding(
config.embedding_size or config.vocab_size,
config.d_model,
),
ln_f=nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False),
))
blocks = [
OlmoBlock(config, linear_method) for i in range(config.n_layers)
]
if self.config.block_group_size > 1:
raise NotImplementedError("Block group size > 1 not supported yet")
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update({
"ff_out":
ColumnParallelLinear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
linear_method=linear_method,
)
})
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore
# Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks):
# shape: (batch_size, seq_len, d_model)
x = block(
positions,
x,
kv_caches[block_idx],
input_metadata,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore
return x
class OLMoForCausalLM(nn.Module):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OlmoModel(config, linear_method)
self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else
self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
# attention
if ".att" in name:
name = name.replace(".att", ".attn.att")
# mlp
if ".ff" in name and "transformer.ff_out" not in name:
name = name.replace(".ff", ".mlp.ff")
# There are no biases in OLMo checkpoints, so no bias remapping is needed.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# coding=utf-8
# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
-# Copyright 2023 The vLLM team.
+# Copyright (c) OrionStar Inc.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
-#
+"""Inference-only Orion-14B model compatible with HuggingFace weights."""
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
-from vllm.transformers_utils.configs.yi import YiConfig
+from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
-from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -49,7 +31,7 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
-class YiMLP(nn.Module):
+class OrionMLP(nn.Module):
def __init__(
self,
@@ -79,7 +61,7 @@ class YiMLP(nn.Module):
return x
-class YiAttention(nn.Module):
+class OrionAttention(nn.Module):
def __init__(
self,
@@ -128,11 +110,12 @@ class YiAttention(nn.Module):
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
-base=self.rope_theta,
+base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
@@ -156,11 +139,11 @@ class YiAttention(nn.Module):
return output
-class YiDecoderLayer(nn.Module):
+class OrionDecoderLayer(nn.Module):
def __init__(
self,
-config: YiConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
@@ -169,7 +152,7 @@ class YiDecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
-self.self_attn = YiAttention(
+self.self_attn = OrionAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
@@ -178,14 +161,17 @@ class YiDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
-self.mlp = YiMLP(
+self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
-self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                    eps=config.rms_norm_eps)
+self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                             eps=config.rms_norm_eps)
def forward(
self,
@@ -196,11 +182,8 @@ class YiDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
-if residual is None:
-    residual = hidden_states
-    hidden_states = self.ln1(hidden_states)
-else:
-    hidden_states, residual = self.ln1(hidden_states, residual)
+residual = hidden_states
+hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -208,17 +191,21 @@ class YiDecoderLayer(nn.Module):
input_metadata=input_metadata,
)
+hidden_states = residual + hidden_states
# Fully Connected
-hidden_states, residual = self.ln2(hidden_states, residual)
+residual = hidden_states
+hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
-return hidden_states, residual
+hidden_states = residual + hidden_states
+return hidden_states, None
-class YiModel(nn.Module):
+class OrionModel(nn.Module):
def __init__(
self,
-config: YiConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
@@ -230,10 +217,10 @@ class YiModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
-YiDecoderLayer(config, linear_method)
+OrionDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
-self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
@@ -253,21 +240,21 @@ class YiModel(nn.Module):
input_metadata,
residual,
)
-hidden_states, _ = self.norm(hidden_states, residual)
+hidden_states = self.norm(hidden_states)
return hidden_states
-class YiForCausalLM(nn.Module):
+class OrionForCausalLM(nn.Module):
def __init__(
self,
-config: YiConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
-self.model = YiModel(config, linear_method)
+self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
@@ -309,6 +296,11 @@ class YiForCausalLM(nn.Module):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
+if ("rotary_emb.cos_cached" in name
+        or "rotary_emb.sin_cached" in name):
+    # Models trained using ColossalAI may include these tensors in
+    # the checkpoint. Skip them.
+    continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
...
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
+from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
@@ -27,7 +28,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -127,7 +127,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
-config: QWenConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
@@ -179,7 +179,7 @@ class QWenModel(nn.Module):
def __init__(
self,
-config: QWenConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
@@ -222,7 +222,7 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
-config: QWenConfig,
+config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
...
# coding=utf-8
# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
+# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
@@ -20,12 +21,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Inference-only Mistral model compatible with HuggingFace weights."""
+"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
-from transformers import MistralConfig
+from transformers import Qwen2Config
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
@@ -49,7 +50,7 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
-class MistralMLP(nn.Module):
+class Qwen2MLP(nn.Module):
def __init__(
self,
@@ -79,7 +80,7 @@ class MistralMLP(nn.Module):
return x
-class MistralAttention(nn.Module):
+class Qwen2Attention(nn.Module):
def __init__(self,
hidden_size: int,
@@ -87,6 +88,7 @@ class MistralAttention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
+use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
@@ -110,14 +112,14 @@ class MistralAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
-self.sliding_window = sliding_window
+self.sliding_window = sliding_window if use_sliding_window else None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
-bias=False,
+bias=True,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
@@ -155,26 +157,29 @@ class MistralAttention(nn.Module):
return output
-class MistralDecoderLayer(nn.Module):
+class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
-config: MistralConfig,
+config: Qwen2Config,
+layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
-rope_theta = getattr(config, "rope_theta", 10000)
+rope_theta = getattr(config, "rope_theta", 1000000)
-self.self_attn = MistralAttention(
+use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
+self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
+use_sliding_window=use_sliding_window,
linear_method=linear_method,
sliding_window=config.sliding_window)
-self.mlp = MistralMLP(
+self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
@@ -214,11 +219,11 @@ class MistralDecoderLayer(nn.Module):
return hidden_states, residual
-class MistralModel(nn.Module):
+class Qwen2Model(nn.Module):
def __init__(
self,
-config: MistralConfig,
+config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
@@ -231,8 +236,8 @@ class MistralModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
-MistralDecoderLayer(config, linear_method)
-for _ in range(config.num_hidden_layers)
+Qwen2DecoderLayer(config, layer_idx, linear_method)
+for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -258,17 +263,17 @@ class MistralModel(nn.Module):
return hidden_states
-class MistralForCausalLM(nn.Module):
+class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
-config: MistralConfig,
+config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
-self.model = MistralModel(config, linear_method)
+self.model = Qwen2Model(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
...
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size
self.total_num_key_value_heads = config.num_key_value_heads
if self.total_num_key_value_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across the tensor parallel GPUs.
assert self.total_num_key_value_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_key_value_heads == 0
self.num_key_value_heads = max(
1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
rope_pct = getattr(config, "rope_pct",
getattr(config, "partial_rotary_factor", 1))
self.rotary_ndims = int(self.head_dim * rope_pct)
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
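# Editorial note (sketch): StableLM uses partial rotary embeddings, i.e. only
# a fraction of each head is rotated. For example, with head_dim=64 and
# rope_pct=0.25, rotary_ndims = int(64 * 0.25) = 16, so get_rope is built with
# rotary_dim=16 and the remaining 48 dims of q/k pass through unrotated.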
class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, residual
class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
try:
from transformers import Starcoder2Config
except ImportError:
# fallback to PretrainedConfig
# NOTE: Please install transformers from source or use transformers>=4.39.0
from transformers import PretrainedConfig as Starcoder2Config
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
self.sliding_window = config.sliding_window
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
linear_method=linear_method,
)
self.act = get_act_fn(config.hidden_act,
intermediate_size=config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
linear_method=linear_method)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# TODO: consider padding_idx (currently removed)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
input_metadata)
hidden_states = self.norm(hidden_states)
return hidden_states
class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
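# Editorial note (sketch): when config.tie_word_embeddings is set, the LM
# head reuses the input embedding matrix, so no separate lm_head module is
# created and "lm_head.weight" checkpoint entries are skipped in
# load_weights below.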
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
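
# Illustrative sketch, not from the original sources: how the
# stacked_params_mapping above rewrites checkpoint names so that separate
# q/k/v projections land in the fused qkv_proj parameter. The weight name
# below is hypothetical.
def _map_stacked_name(name: str) -> tuple:
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # The loaded shard is passed to the fused parameter's
            # weight_loader together with its shard id ("q", "k" or "v").
            return name.replace(weight_name, param_name), shard_id
    return name, None

assert _map_stacked_name("model.layers.0.self_attn.q_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight", "q")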
"""Utilities for selecting and loading models."""
from typing import Type
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
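
# Illustrative sketch, not from the original sources: the map accepts both
# string and torch.dtype keys, so either spelling of a dtype resolves to the
# same Neuron AMP name.
assert TORCH_DTYPE_TO_NEURON_AMP["bfloat16"] == "bf16"
assert TORCH_DTYPE_TO_NEURON_AMP[torch.bfloat16] == "bf16"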
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")
model_class = _get_model_architecture(model_config.hf_config)
linear_method = None
# Create a model instance.
model = model_class(model_config.hf_config, linear_method)
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config)
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
tp_degree=parallel_config.neuron_tp_degree,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
n_positions=[scheduler_config.max_model_len],
batch_size=scheduler_config.max_num_seqs)
return model.eval()
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
import torch
from torch.distributed import ProcessGroup

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
    is_cupy_nccl_enabled_for_all_reduce,
)
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    NOTE: This operation will be applied in-place on the input tensor if
    disable_custom_all_reduce is set to True. Otherwise, this operation may or
    may not be applied in place depending on whether custom all reduce is
    invoked for a particular tensor, which further depends on the tensor size
    and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    out = custom_all_reduce(input_)
    if out is not None:
        return out
    if is_cupy_nccl_enabled_for_all_reduce():
        # TODO: support multiple parallel groups.
        cupy_utils.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    return input_
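
# Illustrative usage sketch, not from the original sources: per the docstring
# above, callers should always rebind to the return value because the custom
# all-reduce path is out-of-place. `_fake_all_reduce` is a hypothetical
# stand-in that only mimics the out-of-place case.
def _fake_all_reduce(x: torch.Tensor) -> torch.Tensor:
    return x.clone()  # out-of-place, like the custom all-reduce kernel

_t = torch.ones(4)
_t = _fake_all_reduce(_t)  # rebind; do not rely on in-place mutation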

def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.

@@ -48,7 +69,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
    return output_tensor


def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across

@@ -80,27 +103,101 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
    return output_tensor


def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_, src=src, group=group)
    return input_


def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict
rank = torch.distributed.get_rank()
if rank == src:
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
f"Tensor {key}: {value} is not on cuda. Currently we only "
f"support broadcasting tensors on cuda.")
metadata_list.append(
(key, TensorMetadata(value.dtype, value.size())))
else:
metadata_list.append((key, value))
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src)
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device="cuda")
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True,
group=group)
async_handles.append(async_handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
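
# Illustrative usage sketch, not from the original sources: the source rank
# passes the full dictionary, every other rank calls with no dict and uses the
# returned one. Guarded so this is a no-op if torch.distributed is not
# initialized.
if torch.distributed.is_initialized():
    if torch.distributed.get_rank() == 0:
        _meta = broadcast_tensor_dict({"step": 3}, src=0)
    else:
        _meta = broadcast_tensor_dict(src=0)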
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
import torch
from torch.distributed import ReduceOp
try:
import cupy
from cupy.cuda import nccl
from cupyx.distributed import NCCLBackend
except ImportError as e:
cupy = e
nccl = None
class NCCLBackend:
...
_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}
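
# Illustrative sketch, not from the original sources: torch ReduceOp values
# are translated to the operation names that CuPy's NCCLBackend expects.
assert _OP_MAPPING[ReduceOp.SUM] == "sum"
assert _OP_MAPPING[ReduceOp.MAX] == "max"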
class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count
def barrier(self) -> None:
raise RuntimeError(
"Currently, CuPy NCCL barrier is not supported since the TCP "
"store is immediately stopped after the initialization.")
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield
def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert not is_initialized()
if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size
# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
_NCCL_BACKEND._store.stop()
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])
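
# Illustrative sketch, not from the original sources, of the bfloat16
# workaround above: reinterpreting a bfloat16 tensor as float16 keeps the same
# 16-bit payload, so the bytes survive the NCCL transfer unchanged and can be
# viewed back afterwards.
_x = torch.ones(4, dtype=torch.bfloat16)
assert torch.equal(_x.view(torch.float16).view(torch.bfloat16), _x)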
def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
def get_nccl_backend():
return _NCCL_BACKEND
from contextlib import contextmanager
from typing import Optional
import torch
import torch.distributed as dist
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)
try:
from vllm._C import custom_ar
import pynvml
except ImportError:
# For AMD GPUs
custom_ar = None
pynvml = None
logger = init_logger(__name__)
_CA_HANDLE = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warn(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try:
begin_capture()
yield
finally:
end_capture()
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
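
# Illustrative usage sketch, not from the original sources: CUDA-graph capture
# of the model forward pass is wrapped in capture() so that the graph buffers
# touched by the custom kernel are registered on exit. The forward callable
# below is hypothetical.
def _capture_with_custom_ar(run_forward):
    with capture():
        return run_forward()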
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
# Query whether the set of GPUs is fully connected by NVLink (1 hop).
@_nvml()
def _is_full_nvlink(rank, world_size):
handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
for i in range(world_size):
if i != rank:
try:
link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
if not link_state:
return False
except pynvml.NVMLError as error:
logger.info(
f"NVLink detection failed with message \"{str(error)}\". "
"This is normal if your machine has no NVLink equipped")
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
continue
if not torch.cuda.can_device_access_peer(rank, i):
return False
return True
class CustomAllreduce:
# max_size: max supported allreduce size
def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
        # Buffer memory is owned by this Python class and passed to C++.
        # The metadata consists of two parts: synchronization metadata
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8,
device="cuda")
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
        # needs fewer than 10000 registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
self.max_size = max_size
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = _is_full_nvlink(rank, world_size)
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
handles, offsets, rank,
self.full_nvlink)
self.fast_cond = self.full_nvlink or world_size <= 2
self.register_buffer(self.buffer)
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
shard_data = (
data[1], # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def close(self):
if self._ptr:
custom_ar.dispose(self._ptr)
self._ptr = 0
def __del__(self):
self.close()

@@ -3,9 +3,12 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib

import torch

from vllm.model_executor.parallel_utils import cupy_utils

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.

@@ -83,6 +86,31 @@ def initialize_model_parallel(
    _PIPELINE_GLOBAL_RANKS = ranks

def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size)
return
assert (
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
), ("tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
assert (get_pipeline_model_parallel_world_size(
) == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: "
f"{get_pipeline_model_parallel_world_size()=} vs. "
f"{pipeline_model_parallel_size=}")
def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TENSOR_MODEL_PARALLEL_GROUP is not None

@@ -92,7 +120,7 @@ def model_parallel_is_initialized():
def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TENSOR_MODEL_PARALLEL_GROUP

@@ -161,7 +189,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()

@@ -170,10 +198,48 @@ def get_pipeline_model_parallel_prev_rank():
def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TENSOR_MODEL_PARALLEL_GROUP
    if _TENSOR_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    if _PIPELINE_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
# Destroy the cupy states if any.
cupy_utils.destroy_process_group()
# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
"""use CuPy nccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
yield
else:
global _ENABLE_CUPY_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True
stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream):
yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old
def is_cupy_nccl_enabled_for_all_reduce():
"""check if CuPy nccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE
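
# Illustrative usage sketch, not from the original sources: the context
# manager above is meant to wrap the region captured into a CUDA graph, so
# that all-reduces issued inside it go through CuPy NCCL rather than
# torch.distributed. The capture callable below is hypothetical.
def _capture_graph_region(capture_fn):
    with with_cupy_nccl_for_all_reduce():
        return capture_fn()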

@@ -5,7 +5,7 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl, is_neuron

_SAMPLING_EPS = 1e-5

@@ -19,6 +19,7 @@ class SamplingMetadata:
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
        generators: List of torch.Generators to use for seeded sampling
        perform_sampling: Whether to perform sampling. This option is used to
            make sampling happen only in the driver worker and disable
            sampling in the other worker processes.

@@ -31,6 +32,7 @@ class SamplingMetadata:
        prompt_lens: Optional[List[int]],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
        generators: Optional[List[torch.Generator]] = None,
        perform_sampling: bool = True,
    ) -> None:
        self.seq_groups = seq_groups

@@ -38,6 +40,7 @@ class SamplingMetadata:
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.generators = generators
        self.perform_sampling = perform_sampling
        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

@@ -152,7 +155,7 @@ class SamplingTensors:
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = not in_wsl() and not is_neuron()
        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
......
"""Utils for model executor.""" """Utils for model executor."""
import random import random
import importlib
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import numpy as np import numpy as np
import torch import torch
from vllm.config import DeviceConfig, ModelConfig
DEVICE_TO_MODEL_LOADER_MAP = {
"cuda": "model_loader",
"neuron": "neuron_model_loader",
}
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
random.seed(seed) random.seed(seed)
...@@ -33,3 +41,12 @@ def set_weight_attrs( ...@@ -33,3 +41,12 @@ def set_weight_attrs(
assert not hasattr( assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}") weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value) setattr(weight, key, value)
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> torch.nn.Module:
model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
imported_model_loader = importlib.import_module(
f"vllm.model_executor.{model_loader_module}")
get_model_fn = imported_model_loader.get_model
return get_model_fn(model_config, device_config, **kwargs)
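
# Illustrative sketch, not from the original sources: the device type on
# DeviceConfig selects which loader module get_model() imports.
assert DEVICE_TO_MODEL_LOADER_MAP["cuda"] == "model_loader"
assert DEVICE_TO_MODEL_LOADER_MAP["neuron"] == "neuron_model_loader"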
"""Utilities for downloading and initializing model weights.""" """Utilities for downloading and initializing model weights."""
import filelock import filelock
import glob import glob
import fnmatch
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download, HfFileSystem
import numpy as np import numpy as np
from safetensors.torch import load_file, save_file, safe_open from safetensors.torch import load_file, save_file, safe_open
import torch import torch
from transformers import PretrainedConfig
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (get_quantization_config, from vllm.model_executor.layers.quantization import (get_quantization_config,
QuantizationConfig) QuantizationConfig)
...@@ -82,25 +83,22 @@ def convert_bin_to_safetensor_file( ...@@ -82,25 +83,22 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place. # TODO(woosuk): Move this to other place.
def get_quant_config( def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
quantization: str, quant_cls = get_quantization_config(model_config.quantization)
model_name_or_path: str,
hf_config: PretrainedConfig,
cache_dir: Optional[str] = None,
) -> QuantizationConfig:
quant_cls = get_quantization_config(quantization)
# Read the quantization config from the HF model config, if available. # Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(hf_config, "quantization_config", None) hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if hf_quant_config is not None: if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
if not is_local: if not is_local:
# Download the config files. # Download the config files.
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, model_config.download_dir):
hf_folder = snapshot_download(model_name_or_path, hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json", allow_patterns="*.json",
cache_dir=cache_dir, cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm) tqdm_class=Disabledtqdm)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
...@@ -111,10 +109,12 @@ def get_quant_config( ...@@ -111,10 +109,12 @@ def get_quant_config(
f.endswith(x) for x in quant_cls.get_config_filenames()) f.endswith(x) for x in quant_cls.get_config_filenames())
] ]
if len(quant_config_files) == 0: if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {quantization}") raise ValueError(
f"Cannot find the config file for {model_config.quantization}")
if len(quant_config_files) > 1: if len(quant_config_files) > 1:
raise ValueError(f"Found multiple config files for {quantization}: " raise ValueError(
f"{quant_config_files}") f"Found multiple config files for {model_config.quantization}: "
f"{quant_config_files}")
quant_config_file = quant_config_files[0] quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f: with open(quant_config_file, "r") as f:
...@@ -149,6 +149,18 @@ def prepare_hf_model_weights( ...@@ -149,6 +149,18 @@ def prepare_hf_model_weights(
allow_patterns += ["*.pt"] allow_patterns += ["*.pt"]
if not is_local: if not is_local:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, cache_dir):
......
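
# Illustrative sketch, not from the original sources, of the format-selection
# loop in prepare_hf_model_weights above: the first pattern that matches any
# remote file wins, so a repo that ships safetensors is downloaded as
# safetensors only. File names and patterns below are hypothetical.
_files = ["model-00001-of-00002.safetensors", "pytorch_model.bin"]
_chosen = next(p for p in ["*.safetensors", "*.bin", "*.pt"]
               if fnmatch.filter(_files, p))
assert _chosen == "*.safetensors"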