Commit 1591c68f authored by zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
...@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size ...@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -55,10 +56,10 @@ def _get_gemma_act_fn( ...@@ -55,10 +56,10 @@ def _get_gemma_act_fn(
"in the config JSON file when it was initially released. " "in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU " "Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy " "(`gelu_pytorch_tanh`). If you want to use the legacy "
f"`{hidden_act}`, edit the config JSON to set " "`%s`, edit the config JSON to set "
f"`hidden_activation={hidden_act}` instead of `hidden_act`. " "`hidden_activation=%s` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 " "See https://github.com/huggingface/transformers/pull/29402 "
"for more details.") "for more details.", hidden_act, hidden_act)
return GeluAndMul(approximate="tanh") return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh": elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh") return GeluAndMul(approximate="tanh")
...@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module): ...@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
intermediate_size: int, intermediate_size: int,
hidden_act: Optional[str] = None, hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None, hidden_activation: Optional[str] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x): def forward(self, x):
...@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module): ...@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
head_dim: int, head_dim: int,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
rope_theta: float = 10000, rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module): ...@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module): ...@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
head_dim=config.head_dim, head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta, rope_theta=config.rope_theta,
linear_method=linear_method, quant_config=quant_config,
) )
self.mlp = GemmaMLP( self.mlp = GemmaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None), hidden_activation=getattr(config, "hidden_activation", None),
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -235,7 +236,7 @@ class GemmaModel(nn.Module): ...@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -245,7 +246,7 @@ class GemmaModel(nn.Module): ...@@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
GemmaDecoderLayer(config, linear_method) GemmaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module): ...@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
del lora_config # Unused. del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = GemmaModel(config, linear_method) self.model = GemmaModel(config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module): ...@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module): ...@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
self.head_dim, self.head_dim,
total_num_heads, total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
...@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module): ...@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -98,15 +99,14 @@ class GPT2MLP(nn.Module): ...@@ -98,15 +99,14 @@ class GPT2MLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size) intermediate_size)
...@@ -122,7 +122,7 @@ class GPT2Block(nn.Module): ...@@ -122,7 +122,7 @@ class GPT2Block(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -130,9 +130,9 @@ class GPT2Block(nn.Module): ...@@ -130,9 +130,9 @@ class GPT2Block(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, linear_method) self.attn = GPT2Attention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, linear_method) self.mlp = GPT2MLP(inner_dim, config, quant_config)
def forward( def forward(
self, self,
...@@ -163,7 +163,7 @@ class GPT2Model(nn.Module): ...@@ -163,7 +163,7 @@ class GPT2Model(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -174,7 +174,7 @@ class GPT2Model(nn.Module): ...@@ -174,7 +174,7 @@ class GPT2Model(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.h = nn.ModuleList([
GPT2Block(config, linear_method) GPT2Block(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -203,12 +203,12 @@ class GPT2LMHeadModel(nn.Module): ...@@ -203,12 +203,12 @@ class GPT2LMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = GPT2Model(config, linear_method) self.transformer = GPT2Model(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module): ...@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
total_num_heads, total_num_heads,
total_num_kv_heads, total_num_kv_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module): ...@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -119,15 +120,14 @@ class GPTBigMLP(nn.Module): ...@@ -119,15 +120,14 @@ class GPTBigMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size) intermediate_size)
...@@ -143,7 +143,7 @@ class GPTBigCodeBlock(nn.Module): ...@@ -143,7 +143,7 @@ class GPTBigCodeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -151,9 +151,9 @@ class GPTBigCodeBlock(nn.Module): ...@@ -151,9 +151,9 @@ class GPTBigCodeBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, linear_method) self.attn = GPTBigCodeAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, linear_method) self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward( def forward(
self, self,
...@@ -184,7 +184,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -184,7 +184,7 @@ class GPTBigCodeModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -195,7 +195,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -195,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.h = nn.ModuleList([
GPTBigCodeBlock(config, linear_method) GPTBigCodeBlock(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -224,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -224,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, linear_method) self.transformer = GPTBigCodeModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module): ...@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
...@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module): ...@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
self.head_size, self.head_size,
self.total_num_heads, self.total_num_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
...@@ -105,21 +106,20 @@ class GPTJMLP(nn.Module): ...@@ -105,21 +106,20 @@ class GPTJMLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPTJConfig, config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.n_embd hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear( self.fc_in = ColumnParallelLinear(
hidden_size, hidden_size,
intermediate_size, intermediate_size,
linear_method=linear_method, quant_config=quant_config,
) )
self.fc_out = RowParallelLinear( self.fc_out = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size) intermediate_size)
...@@ -135,14 +135,14 @@ class GPTJBlock(nn.Module): ...@@ -135,14 +135,14 @@ class GPTJBlock(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
inner_dim = (4 * config.n_embd inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner) if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, linear_method) self.attn = GPTJAttention(config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, linear_method) self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward( def forward(
self, self,
...@@ -169,7 +169,7 @@ class GPTJModel(nn.Module): ...@@ -169,7 +169,7 @@ class GPTJModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -179,7 +179,7 @@ class GPTJModel(nn.Module): ...@@ -179,7 +179,7 @@ class GPTJModel(nn.Module):
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList( self.h = nn.ModuleList(
[GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) [GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -207,13 +207,13 @@ class GPTJForCausalLM(nn.Module): ...@@ -207,13 +207,13 @@ class GPTJForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, linear_method) self.transformer = GPTJModel(config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.n_embd, config.n_embd,
......
...@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module): ...@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
...@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module): ...@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
self.head_size, self.head_size,
self.total_num_heads, self.total_num_heads,
bias=self.bias, bias=self.bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
bias=self.bias, bias=self.bias,
linear_method=linear_method, quant_config=quant_config,
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct) rotary_dim = int(self.head_size * config.rotary_pct)
...@@ -105,20 +106,19 @@ class GPTNeoXMLP(nn.Module): ...@@ -105,20 +106,19 @@ class GPTNeoXMLP(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size) config.intermediate_size)
...@@ -134,7 +134,7 @@ class GPTNeoXLayer(nn.Module): ...@@ -134,7 +134,7 @@ class GPTNeoXLayer(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
...@@ -142,8 +142,8 @@ class GPTNeoXLayer(nn.Module): ...@@ -142,8 +142,8 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, linear_method) self.attention = GPTNeoXAttention(config, quant_config)
self.mlp = GPTNeoXMLP(config, linear_method) self.mlp = GPTNeoXMLP(config, quant_config)
def forward( def forward(
self, self,
...@@ -182,7 +182,7 @@ class GPTNeoXModel(nn.Module): ...@@ -182,7 +182,7 @@ class GPTNeoXModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -192,7 +192,7 @@ class GPTNeoXModel(nn.Module): ...@@ -192,7 +192,7 @@ class GPTNeoXModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
GPTNeoXLayer(config, linear_method) GPTNeoXLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
...@@ -223,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -223,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, linear_method) self.gpt_neox = GPTNeoXModel(config, quant_config)
self.embed_out = ParallelLMHead( self.embed_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module): ...@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.w2 = RowParallelLinear(intermediate_size, self.w2 = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module): ...@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module): ...@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.wo = RowParallelLinear( self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module): ...@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module): ...@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
self.feed_forward = InternLM2MLP( self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.attention_norm = RMSNorm(config.hidden_size, self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module): ...@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module): ...@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
InternLMDecoderLayer(config, linear_method) InternLMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = InternLM2Model(config, linear_method) self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -68,7 +69,7 @@ class JAISAttention(nn.Module): ...@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -88,13 +89,13 @@ class JAISAttention(nn.Module): ...@@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
self.head_dim, self.head_dim,
total_num_heads, total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -128,7 +129,7 @@ class JAISMLP(nn.Module): ...@@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: JAISConfig, config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -137,19 +138,19 @@ class JAISMLP(nn.Module): ...@@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_fc2 = (ColumnParallelLinear( self.c_fc2 = (ColumnParallelLinear(
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) if self.swiglu else None) ) if self.swiglu else None)
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.act = SwiGLUActivation() self.act = SwiGLUActivation()
...@@ -169,7 +170,7 @@ class JAISBlock(nn.Module): ...@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -177,9 +178,9 @@ class JAISBlock(nn.Module): ...@@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, linear_method) self.attn = JAISAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, linear_method) self.mlp = JAISMLP(inner_dim, config, quant_config)
def forward( def forward(
self, self,
...@@ -210,7 +211,7 @@ class JAISModel(nn.Module): ...@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -227,7 +228,7 @@ class JAISModel(nn.Module): ...@@ -227,7 +228,7 @@ class JAISModel(nn.Module):
else: else:
self.embeddings_scale = config.mup_embeddings_scale self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([ self.h = nn.ModuleList([
JAISBlock(config, linear_method) JAISBlock(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module): ...@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = JAISModel(config, linear_method) self.transformer = JAISModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
......
...@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module): ...@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module): ...@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
...@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module): ...@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -174,12 +175,16 @@ class LlamaDecoderLayer(nn.Module): ...@@ -174,12 +175,16 @@ class LlamaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
sliding_window = getattr(config, "sliding_window", None) sliding_window = getattr(config, "sliding_window", None)
...@@ -195,7 +200,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -195,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
bias=attention_bias, bias=attention_bias,
sliding_window=sliding_window, sliding_window=sliding_window,
) )
...@@ -203,7 +208,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -203,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -244,7 +249,7 @@ class LlamaModel(nn.Module): ...@@ -244,7 +249,7 @@ class LlamaModel(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -260,7 +265,7 @@ class LlamaModel(nn.Module): ...@@ -260,7 +265,7 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method) LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -325,13 +330,12 @@ class LlamaForCausalLM(nn.Module): ...@@ -325,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
...@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module): ...@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
("qkv_proj", "k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
("qkv_proj", "v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "gate_proj", 0), (".gate_up_proj", ".gate_proj", 0),
("gate_up_proj", "up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig ...@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): ...@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self, def __init__(self,
config: "LlavaConfig", config: "LlavaConfig",
vision_language_config: VisionLanguageConfig, vision_language_config: VisionLanguageConfig,
linear_method: Optional["LinearMethodBase"] = None) -> None: quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module): ...@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.linear_method = linear_method self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, linear_method) self.language_model = LlamaModel(config.text_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
......
...@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module): ...@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
linear_method=None) quant_config=None)
self.ws = nn.Parameter( self.ws = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
...@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module): ...@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module): ...@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module): ...@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0: if self.num_experts == 0:
...@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
else: else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts, self.mlp = MiniCPMMoE(num_experts=config.num_experts,
...@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module): ...@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module): ...@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, linear_method) MiniCPMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module): ...@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self.linear_method = linear_method self.quant_config = quant_config
self.model = MiniCPMModel(config, self.model = MiniCPMModel(config,
linear_method, quant_config,
lora_config=lora_config) lora_config=lora_config)
unpadded_vocab_size = config.vocab_size unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
......
...@@ -27,6 +27,7 @@ import torch ...@@ -27,6 +27,7 @@ import torch
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
...@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (QKVParallelLinear,
QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod, from vllm.model_executor.layers.quantization.base_config import (
per_tensor_quantize) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module): ...@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
intermediate_size: int, intermediate_size: int,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.tp_size = tp_size or get_tensor_model_parallel_world_size()
...@@ -77,50 +78,90 @@ class MixtralMoE(nn.Module): ...@@ -77,50 +78,90 @@ class MixtralMoE(nn.Module):
self.top_k = top_k self.top_k = top_k
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different # FIXME(pcmoritz): Make this more general to support different
# quantization schemes # quantization schemes
self.use_fp8 = isinstance(linear_method, Fp8LinearMethod) self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size, self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
linear_method=None) quant_config=None)
if self.use_fp8:
params_dtype = torch.float8_e4m3fn
self.ws = nn.Parameter( self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
device="cuda", dtype=params_dtype))
dtype=self.params_dtype)) self.w2_weight = nn.Parameter(
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
device="cuda", dtype=params_dtype))
dtype=self.params_dtype))
set_weight_attrs(self.w13_weight, {
# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
self.w2s_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) })
set_weight_attrs(self.w2s, { set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
}) })
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int): weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -134,18 +175,49 @@ class MixtralMoE(nn.Module): ...@@ -134,18 +175,49 @@ class MixtralMoE(nn.Module):
shard_size:2 * shard_size, :] = loaded_weight[shard, :] shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"): if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard] param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight
def process_weights_after_loading(self): def process_weights_after_loading(self):
if self.use_fp8: # Fp8 is the only case where we need to process after loading.
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) if not self.use_fp8:
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts): for expert in range(self.num_total_experts):
ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize( w13_weight[expert, :, :], self.w13_scale[
self.ws.data[expert, :, :]) expert] = ops.scaled_fp8_quant(
w2s[expert, :, :], self.w2s_scale[ self.w13_weight.data[expert, :, :])
expert] = per_tensor_quantize(self.w2s.data[expert, :, :]) w2_weight[expert, :, :], self.w2_scale[
self.ws = nn.Parameter(ws, requires_grad=False) expert] = ops.scaled_fp8_quant(
self.w2s = nn.Parameter(w2s, requires_grad=False) self.w2_weight.data[expert, :, :])
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
...@@ -153,15 +225,17 @@ class MixtralMoE(nn.Module): ...@@ -153,15 +225,17 @@ class MixtralMoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states, final_hidden_states = fused_moe(hidden_states,
self.ws, self.w13_weight,
self.w2s, self.w2_weight,
router_logits, router_logits,
self.top_k, self.top_k,
renormalize=True, renormalize=True,
inplace=True, inplace=True,
use_fp8=self.use_fp8, use_fp8=self.use_fp8,
w1_scale=self.ws_scale, w1_scale=self.w13_scale,
w2_scale=self.w2s_scale) w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -178,7 +252,7 @@ class MixtralAttention(nn.Module): ...@@ -178,7 +252,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -203,12 +277,14 @@ class MixtralAttention(nn.Module): ...@@ -203,12 +277,14 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.sliding_window = sliding_window self.sliding_window = sliding_window
if isinstance(linear_method, Fp8LinearMethod): if isinstance(
quant_config,
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
print_warning_once( print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize " "For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved." "the attention layers until their FP8 performance is improved."
) )
linear_method = None quant_config = None
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
...@@ -216,13 +292,13 @@ class MixtralAttention(nn.Module): ...@@ -216,13 +292,13 @@ class MixtralAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -259,7 +335,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -259,7 +335,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -272,13 +348,13 @@ class MixtralDecoderLayer(nn.Module): ...@@ -272,13 +348,13 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
linear_method=linear_method) quant_config=quant_config)
self.block_sparse_moe = MixtralMoE( self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
linear_method=linear_method) quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -318,7 +394,7 @@ class MixtralModel(nn.Module): ...@@ -318,7 +394,7 @@ class MixtralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -334,7 +410,7 @@ class MixtralModel(nn.Module): ...@@ -334,7 +410,7 @@ class MixtralModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method) MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -384,14 +460,13 @@ class MixtralForCausalLM(nn.Module): ...@@ -384,14 +460,13 @@ class MixtralForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, self.model = MixtralModel(config,
linear_method, quant_config,
lora_config=lora_config) lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
...@@ -443,11 +518,26 @@ class MixtralForCausalLM(nn.Module): ...@@ -443,11 +518,26 @@ class MixtralForCausalLM(nn.Module):
] ]
expert_params_mapping = [ expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id) # (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s", ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id) f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts) for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"] for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -486,3 +576,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -486,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every element of a 1-D tensor is close to its first.

    Closeness is decided by ``torch.allclose`` with its default tolerances,
    using the first element as ``input`` and each element as ``other``.

    Args:
        x: A one-dimensional tensor.

    Returns:
        True if all elements are close to ``x[0]``; True for an empty tensor.

    Raises:
        AssertionError: If ``x`` is not one-dimensional.
    """
    assert len(x.shape) == 1
    # An empty tensor has no first element to index; it is vacuously uniform.
    if x.numel() == 0:
        return True
    # One broadcast comparison instead of one torch.allclose call per element.
    # Argument order (x[0] as input, x as other) is kept because the relative
    # tolerance of torch.allclose is scaled by ``other``.
    return torch.allclose(x[0].expand_as(x), x)
...@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (QKVParallelLinear,
QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module): ...@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.num_experts = num_experts self.num_experts = num_experts
...@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module): ...@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
self.w1 = ReplicatedLinear(self.hidden_dim, self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim, self.ffn_dim,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim, self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim, self.hidden_dim,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim, self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim, self.ffn_dim,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
# TODO: Use vllm's SiluAndMul # TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU() self.act_fn = nn.SiLU()
...@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module): ...@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module): ...@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
MixtralMLP(self.num_total_experts, MixtralMLP(self.num_total_experts,
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
linear_method=linear_method) quant_config=quant_config)
if idx in self.expert_indicies else None if idx in self.expert_indicies else None
for idx in range(self.num_total_experts) for idx in range(self.num_total_experts)
]) ])
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
linear_method=None) quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
...@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module): ...@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module): ...@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module): ...@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
linear_method=linear_method) quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config, self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method) quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -291,7 +292,7 @@ class MixtralModel(nn.Module): ...@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -302,7 +303,7 @@ class MixtralModel(nn.Module): ...@@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method) MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module): ...@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = MixtralModel(config, linear_method) self.model = MixtralModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -42,7 +43,7 @@ class MPTAttention(nn.Module): ...@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -65,7 +66,7 @@ class MPTAttention(nn.Module): ...@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=not config.no_bias, bias=not config.no_bias,
linear_method=linear_method, quant_config=quant_config,
) )
if self.qk_ln: if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model) self.q_ln = nn.LayerNorm(self.d_model)
...@@ -74,7 +75,7 @@ class MPTAttention(nn.Module): ...@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
self.d_model, self.d_model,
self.d_model, self.d_model,
bias=not config.no_bias, bias=not config.no_bias,
linear_method=linear_method, quant_config=quant_config,
) )
tp_world_size = get_tensor_model_parallel_world_size() tp_world_size = get_tensor_model_parallel_world_size()
...@@ -133,7 +134,7 @@ class MPTMLP(nn.Module): ...@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
...@@ -143,15 +144,14 @@ class MPTMLP(nn.Module): ...@@ -143,15 +144,14 @@ class MPTMLP(nn.Module):
hidden_size, hidden_size,
intermediate_size, intermediate_size,
bias=not config.no_bias, bias=not config.no_bias,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, intermediate_size) self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
bias=not config.no_bias, bias=not config.no_bias,
linear_method=linear_method, quant_config=quant_config,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -166,14 +166,14 @@ class MPTBlock(nn.Module): ...@@ -166,14 +166,14 @@ class MPTBlock(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size) self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, linear_method) self.attn = MPTAttention(config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size) self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, linear_method) self.ffn = MPTMLP(config, quant_config)
def forward( def forward(
self, self,
...@@ -201,7 +201,7 @@ class MPTModel(nn.Module): ...@@ -201,7 +201,7 @@ class MPTModel(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
assert config.embedding_fraction == 1.0 assert config.embedding_fraction == 1.0
...@@ -212,7 +212,7 @@ class MPTModel(nn.Module): ...@@ -212,7 +212,7 @@ class MPTModel(nn.Module):
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[MPTBlock(config, linear_method) for _ in range(config.n_layers)]) [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:
for module in self.modules(): for module in self.modules():
...@@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module): ...@@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = MPTModel(config, linear_method) self.transformer = MPTModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
# coding=utf-8 # coding=utf-8
# Adapted from # Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py # Copyright 2024 The vLLM team.
# Copyright 2023 The vLLM team. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# #
# BSD 3-Clause License # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
# #
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. # Licensed under the Apache License, Version 2.0 (the "License");
# All rights reserved. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# Redistribution and use in source and binary forms, with or without # http://www.apache.org/licenses/LICENSE-2.0
# modification, are permitted provided that the following conditions are met:
# #
# * Redistributions of source code must retain the above copyright notice, this # Unless required by applicable law or agreed to in writing, software
# list of conditions and the following disclaimer. # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * Redistributions in binary form must reproduce the above copyright notice, # See the License for the specific language governing permissions and
# this list of conditions and the following disclaimer in the documentation # limitations under the License.
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
# this model must need this dependency
from hf_olmo import OLMoConfig
from torch import nn from torch import nn
from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -70,56 +54,53 @@ class OlmoAttention(nn.Module): ...@@ -70,56 +54,53 @@ class OlmoAttention(nn.Module):
def __init__( def __init__(
self, self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.d_model self.hidden_size = config.hidden_size
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = ( tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
self.total_num_heads = self.config.n_heads self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads // self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.clip_qkv = config.clip_qkv
# Layer norms.
self.attn_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Attention input projection. Projects x -> (q, k, v) # Attention input projection. Projects x -> (q, k, v)
self.att_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
config.d_model, self.hidden_size,
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=config.include_bias, bias=config.attention_bias,
linear_method=linear_method, quant_config=quant_config,
) )
# Rotary embeddings. # Rotary embeddings.
if self.config.rope: self.rotary_emb = get_rope(
rope_theta = getattr(config, "rope_theta", 10000) self.head_dim,
max_position_embeddings = getattr(config, rotary_dim=self.head_dim,
"max_position_embeddings", 8192) max_position=self.max_position_embeddings,
self.rotary_emb = get_rope( base=self.rope_theta,
self.head_dim, )
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling)
# Attention output projection. # Attention output projection.
self.attn_out = RowParallelLinear( self.o_proj = RowParallelLinear(
config.d_model, self.hidden_size,
config.d_model, self.hidden_size,
bias=config.include_bias, bias=config.attention_bias,
linear_method=linear_method, quant_config=quant_config,
) )
def forward( def forward(
...@@ -129,13 +110,13 @@ class OlmoAttention(nn.Module): ...@@ -129,13 +110,13 @@ class OlmoAttention(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
qkv, _ = self.att_proj(hidden_states) if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope: q, k = self.rotary_emb(positions, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.attn_out(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -148,57 +129,44 @@ class OlmoMLP(nn.Module): ...@@ -148,57 +129,44 @@ class OlmoMLP(nn.Module):
def __init__( def __init__(
self, self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size self.hidden_size = config.hidden_size
is not None else config.mlp_ratio * config.d_model) self.intermediate_size = config.intermediate_size
# Layer norms.
self.ff_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Feed-forward input projection. # Feed-forward input projection.
self.ff_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
config.d_model, self.hidden_size,
[self.hidden_size // 2] * 2, [self.intermediate_size] * 2,
bias=config.include_bias, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
# Activation function. # Activation function.
self.act = SiluAndMul() self.act_fn = SiluAndMul()
self.act.output_multiplier = 0.5
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection. # Feed-forward output projection.
self.ff_out = RowParallelLinear( self.down_proj = RowParallelLinear(
int(self.act.output_multiplier * self.hidden_size), self.intermediate_size,
config.d_model, self.hidden_size,
bias=config.include_bias, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Add feed-forward projection. gate_up, _ = self.gate_up_proj(x)
# shape: (batch_size, seq_len, d_model) x = self.act_fn(gate_up)
og_x = x x, _ = self.down_proj(x)
x = self.ff_norm(x)
x, _ = self.ff_proj(x)
x = self.act(x)
x, _ = self.ff_out(x)
x = og_x + x
return x return x
class OlmoBlock(nn.Module): class OlmoDecoderLayer(nn.Module):
""" """
This is a typical transformer block where the output is This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))`` computed as ``MLP(LN(x + Attention(LN(x))))``
...@@ -206,14 +174,22 @@ class OlmoBlock(nn.Module): ...@@ -206,14 +174,22 @@ class OlmoBlock(nn.Module):
""" """
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.attn = OlmoAttention(config, linear_method) self.self_attn = OlmoAttention(config, quant_config)
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, linear_method) self.mlp = OlmoMLP(config, quant_config)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
def forward( def forward(
self, self,
...@@ -223,52 +199,37 @@ class OlmoBlock(nn.Module): ...@@ -223,52 +199,37 @@ class OlmoBlock(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block. # Attention block.
og_x = hidden_states residual = hidden_states
x = self.attn(positions, hidden_states, kv_cache, attn_metadata) hidden_states = self.input_layernorm(hidden_states)
x = x + og_x hidden_states = self.self_attn(positions, hidden_states, kv_cache,
attn_metadata)
hidden_states = hidden_states + residual
# MLP block. # MLP block.
hidden_states = self.mlp(x) residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states return hidden_states
class OlmoModel(nn.Module): class OlmoModel(nn.Module):
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = nn.ModuleDict( self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
dict( config.hidden_size)
wte=VocabParallelEmbedding( self.layers = nn.ModuleList([
config.embedding_size or config.vocab_size, OlmoDecoderLayer(config, quant_config)
config.d_model, for layer_idx in range(config.num_hidden_layers)
), ])
ln_f=nn.LayerNorm(config.d_model, self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False, elementwise_affine=False,
bias=False), bias=False)
))
blocks = [
OlmoBlock(config, linear_method) for i in range(config.n_layers)
]
if self.config.block_group_size > 1:
raise NotImplementedError("Block group size > 1 not supported yet")
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update({
"ff_out":
ColumnParallelLinear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
linear_method=linear_method,
)
})
def forward( def forward(
self, self,
...@@ -282,39 +243,48 @@ class OlmoModel(nn.Module): ...@@ -282,39 +243,48 @@ class OlmoModel(nn.Module):
""" """
# Get embeddings of input. # Get embeddings of input.
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# Apply blocks one-by-one. # Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks): for layer_idx, decoder_layer in enumerate(self.layers):
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
x = block( hidden_states = decoder_layer(
positions, positions,
x, hidden_states,
kv_caches[block_idx], kv_caches[layer_idx],
attn_metadata, attn_metadata,
) )
# Apply final layer norm. # Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model) # shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore hidden_states = self.norm(hidden_states)
return x return hidden_states
class OLMoForCausalLM(nn.Module): class OlmoForCausalLM(nn.Module):
""" """
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.model = OlmoModel(config, quant_config)
self.model = OlmoModel(config, linear_method) if config.tie_word_embeddings:
self.lm_head_weight = (self.model.transformer.wte.weight self.lm_head_weight = self.model.embed_tokens.weight
if config.weight_tying else else:
self.model.transformer.ff_out.weight) self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -348,20 +318,39 @@ class OLMoForCausalLM(nn.Module): ...@@ -348,20 +318,39 @@ class OLMoForCausalLM(nn.Module):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
# attention if "rotary_emb.inv_freq" in name:
if ".att" in name: continue
name = name.replace(".att", ".attn.att") if ("rotary_emb.cos_cached" in name
# mlp or "rotary_emb.sin_cached" in name):
if ".ff_proj" in name: # Models trained using ColossalAI may include these tensors in
name = name.replace(".ff_proj", ".mlp.ff_proj") # the checkpoint. Skip them.
# Reverse the weight for the MergeColumnParallelLinear continue
loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) for (param_name, weight_name, shard_id) in stacked_params_mapping:
if ".ff_out" in name and "transformer.ff_out" not in name: if weight_name not in name:
name = name.replace(".ff_out", ".mlp.ff_out") continue
# there is no bias in olmo name = name.replace(weight_name, param_name)
param = params_dict[name] # Skip loading extra bias for GPTQ models.
weight_loader = getattr(param, "weight_loader", if name.endswith(".bias") and name not in params_dict:
default_weight_loader) continue
weight_loader(param, loaded_weight) param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -60,7 +61,7 @@ class OPTAttention(nn.Module): ...@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
bias: bool = True, bias: bool = True,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -77,13 +78,13 @@ class OPTAttention(nn.Module): ...@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
self.head_dim, self.head_dim,
total_num_heads, total_num_heads,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
embed_dim, embed_dim,
embed_dim, embed_dim,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.enable_bias, bias=config.enable_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.do_layer_norm_before = config.do_layer_norm_before self.do_layer_norm_before = config.do_layer_norm_before
...@@ -127,16 +128,15 @@ class OPTDecoderLayer(nn.Module): ...@@ -127,16 +128,15 @@ class OPTDecoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
config.ffn_dim, config.ffn_dim,
bias=config.enable_bias, bias=config.enable_bias,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function, self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim) quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.ffn_dim, config.ffn_dim,
self.embed_dim, self.embed_dim,
bias=config.enable_bias, bias=config.enable_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
self.embed_dim, self.embed_dim,
...@@ -181,7 +181,7 @@ class OPTDecoder(nn.Module): ...@@ -181,7 +181,7 @@ class OPTDecoder(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -202,7 +202,7 @@ class OPTDecoder(nn.Module): ...@@ -202,7 +202,7 @@ class OPTDecoder(nn.Module):
self.project_out = ReplicatedLinear(config.hidden_size, self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim, config.word_embed_proj_dim,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
else: else:
self.project_out = None self.project_out = None
...@@ -210,7 +210,7 @@ class OPTDecoder(nn.Module): ...@@ -210,7 +210,7 @@ class OPTDecoder(nn.Module):
self.project_in = ReplicatedLinear(config.word_embed_proj_dim, self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
else: else:
self.project_in = None self.project_in = None
...@@ -226,7 +226,7 @@ class OPTDecoder(nn.Module): ...@@ -226,7 +226,7 @@ class OPTDecoder(nn.Module):
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
OPTDecoderLayer(config, linear_method) OPTDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -259,10 +259,10 @@ class OPTModel(nn.Module): ...@@ -259,10 +259,10 @@ class OPTModel(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.decoder = OPTDecoder(config, linear_method) self.decoder = OPTDecoder(config, quant_config)
def forward( def forward(
self, self,
...@@ -279,12 +279,12 @@ class OPTForCausalLM(nn.Module): ...@@ -279,12 +279,12 @@ class OPTForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = OPTModel(config, linear_method) self.model = OPTModel(config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -13,11 +13,12 @@ from transformers import PretrainedConfig ...@@ -13,11 +13,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -34,17 +35,17 @@ class OrionMLP(nn.Module): ...@@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -67,7 +68,7 @@ class OrionAttention(nn.Module): ...@@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -98,13 +99,13 @@ class OrionAttention(nn.Module): ...@@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module): ...@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
self.mlp = OrionMLP( self.mlp = OrionMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -201,7 +202,7 @@ class OrionModel(nn.Module): ...@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -212,7 +213,7 @@ class OrionModel(nn.Module): ...@@ -212,7 +213,7 @@ class OrionModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
OrionDecoderLayer(config, linear_method) OrionDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module): ...@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = OrionModel(config, linear_method) self.model = OrionModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -62,7 +63,7 @@ class PhiAttention(nn.Module): ...@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -80,12 +81,12 @@ class PhiAttention(nn.Module): ...@@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
self.head_size, self.head_size,
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
...@@ -125,7 +126,7 @@ class PhiMLP(nn.Module): ...@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
n_inner = getattr(config, "n_inner", None) n_inner = getattr(config, "n_inner", None)
...@@ -134,14 +135,13 @@ class PhiMLP(nn.Module): ...@@ -134,14 +135,13 @@ class PhiMLP(nn.Module):
self.fc1 = ColumnParallelLinear( self.fc1 = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
n_inner, n_inner,
linear_method=linear_method, quant_config=quant_config,
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
n_inner, n_inner,
config.hidden_size, config.hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner) self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -155,12 +155,12 @@ class PhiLayer(nn.Module): ...@@ -155,12 +155,12 @@ class PhiLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, linear_method) self.self_attn = PhiAttention(config, quant_config)
self.mlp = PhiMLP(config, linear_method) self.mlp = PhiMLP(config, quant_config)
def forward( def forward(
self, self,
...@@ -186,14 +186,14 @@ class PhiModel(nn.Module): ...@@ -186,14 +186,14 @@ class PhiModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
PhiLayer(config, linear_method) PhiLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -225,12 +225,12 @@ class PhiForCausalLM(nn.Module): ...@@ -225,12 +225,12 @@ class PhiForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = PhiModel(config, linear_method) self.model = PhiModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -35,17 +36,17 @@ class QWenMLP(nn.Module): ...@@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str = "silu", hidden_act: str = "silu",
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.c_proj = RowParallelLinear(intermediate_size, self.c_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -67,7 +68,7 @@ class QWenAttention(nn.Module): ...@@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
max_position_embeddings: int, max_position_embeddings: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -83,13 +84,13 @@ class QWenAttention(nn.Module): ...@@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
...@@ -122,7 +123,7 @@ class QWenBlock(nn.Module): ...@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -134,13 +135,13 @@ class QWenBlock(nn.Module): ...@@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
config.max_position_embeddings, config.max_position_embeddings,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
linear_method=linear_method) quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size, self.mlp = QWenMLP(config.hidden_size,
config.intermediate_size // 2, config.intermediate_size // 2,
linear_method=linear_method) quant_config=quant_config)
def forward( def forward(
self, self,
...@@ -174,7 +175,7 @@ class QWenModel(nn.Module): ...@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -185,7 +186,7 @@ class QWenModel(nn.Module): ...@@ -185,7 +186,7 @@ class QWenModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.h = nn.ModuleList([ self.h = nn.ModuleList([
QWenBlock(config, linear_method) QWenBlock(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module): ...@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = QWenModel(config, linear_method) self.transformer = QWenModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig ...@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module): ...@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module): ...@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
use_sliding_window: bool = False, use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module): ...@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
config: Qwen2Config, config: Qwen2Config,
layer_idx: int, layer_idx: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
use_sliding_window=use_sliding_window, use_sliding_window=use_sliding_window,
linear_method=linear_method, quant_config=quant_config,
sliding_window=config.sliding_window) sliding_window=config.sliding_window)
self.mlp = Qwen2MLP( self.mlp = Qwen2MLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module): ...@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module): ...@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, linear_method) Qwen2DecoderLayer(config, layer_idx, quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
del lora_config del lora_config
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = Qwen2Model(config, linear_method) self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head_weight = self.model.embed_tokens.weight
......
...@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module): ...@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
reduce_results=reduce_results) reduce_results=reduce_results)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
...@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
Qwen2MoeMLP(hidden_size=config.hidden_size, Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False) reduce_results=False)
for idx in range(self.n_routed_experts) for idx in range(self.n_routed_experts)
]) ])
...@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts, self.n_routed_experts,
bias=False, bias=False,
linear_method=None) quant_config=None)
if config.shared_expert_intermediate_size > 0: if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP( self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size, intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False, reduce_results=False,
) )
else: else:
...@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module): ...@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module): ...@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
layer_idx: int, layer_idx: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
if (config.num_experts is not None if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0): and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config, self.mlp = Qwen2MoeSparseMoeBlock(config=config,
linear_method=linear_method) quant_config=quant_config)
else: else:
self.mlp = Qwen2MoeMLP( self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config, Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
layer_idx,
linear_method=linear_method)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = Qwen2MoeModel(config, linear_method) self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment