Unverified Commit a62aaf1d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

parent 603ad848
......@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
......@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
quant_config=None)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False,
)
......@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
linear_method=linear_method)
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method)
self.quant_config = quant_config
self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
......@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
......@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -201,8 +202,8 @@ class FalconMLP(nn.Module):
4 * hidden_size,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method)
quant_config = getattr(linear_method, "quant_config", None)
quant_config=quant_config)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
......@@ -212,7 +213,7 @@ class FalconMLP(nn.Module):
bias=config.bias,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method)
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
......@@ -229,13 +230,13 @@ class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, linear_method)
self.mlp = FalconMLP(config, linear_method)
self.self_attention = FalconAttention(config, quant_config)
self.mlp = FalconMLP(config, quant_config)
self.config = config
if config.new_decoder_architecture:
......@@ -311,7 +312,7 @@ class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -327,7 +328,7 @@ class FalconModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, linear_method)
FalconDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -359,12 +360,12 @@ class FalconForCausalLM(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = FalconModel(config, linear_method)
self.quant_config = quant_config
self.transformer = FalconModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
......@@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
......@@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
linear_method=linear_method,
quant_config=quant_config,
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, linear_method)
GemmaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = GemmaModel(config, linear_method)
self.quant_config = quant_config
self.model = GemmaModel(config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
self.head_dim,
total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
......@@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
self,
intermediate_size: int,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -98,15 +99,15 @@ class GPT2MLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
......@@ -122,7 +123,7 @@ class GPT2Block(nn.Module):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -130,9 +131,9 @@ class GPT2Block(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, linear_method)
self.attn = GPT2Attention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, linear_method)
self.mlp = GPT2MLP(inner_dim, config, quant_config)
def forward(
self,
......@@ -163,7 +164,7 @@ class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -174,7 +175,7 @@ class GPT2Model(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPT2Block(config, linear_method)
GPT2Block(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -203,12 +204,12 @@ class GPT2LMHeadModel(nn.Module):
def __init__(
self,
config: GPT2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = GPT2Model(config, linear_method)
self.quant_config = quant_config
self.transformer = GPT2Model(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
total_num_heads,
total_num_kv_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
......@@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
self,
intermediate_size: int,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -119,15 +120,15 @@ class GPTBigMLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
......@@ -143,7 +144,7 @@ class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -151,9 +152,9 @@ class GPTBigCodeBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, linear_method)
self.attn = GPTBigCodeAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, linear_method)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward(
self,
......@@ -184,7 +185,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -195,7 +196,7 @@ class GPTBigCodeModel(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPTBigCodeBlock(config, linear_method)
GPTBigCodeBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -224,12 +225,12 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = GPTBigCodeModel(config, linear_method)
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
......@@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
self.head_size,
self.total_num_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -105,21 +106,21 @@ class GPTJMLP(nn.Module):
self,
intermediate_size: int,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
linear_method=linear_method,
quant_config=quant_config,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
......@@ -135,14 +136,14 @@ class GPTJBlock(nn.Module):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config, linear_method)
self.attn = GPTJAttention(config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
self,
......@@ -169,7 +170,7 @@ class GPTJModel(nn.Module):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -179,7 +180,7 @@ class GPTJModel(nn.Module):
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -207,13 +208,13 @@ class GPTJForCausalLM(nn.Module):
def __init__(
self,
config: GPTJConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, linear_method)
self.transformer = GPTJModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
......
......@@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
......@@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
self.head_size,
self.total_num_heads,
bias=self.bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
linear_method=linear_method,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
......@@ -105,20 +106,20 @@ class GPTNeoXMLP(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
......@@ -134,7 +135,7 @@ class GPTNeoXLayer(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
......@@ -142,8 +143,8 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, linear_method)
self.mlp = GPTNeoXMLP(config, linear_method)
self.attention = GPTNeoXAttention(config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
self,
......@@ -182,7 +183,7 @@ class GPTNeoXModel(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -192,7 +193,7 @@ class GPTNeoXModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
GPTNeoXLayer(config, linear_method)
GPTNeoXLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
......@@ -223,12 +224,12 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.gpt_neox = GPTNeoXModel(config, linear_method)
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, quant_config)
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
......
......@@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.w2 = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config, linear_method)
InternLMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = InternLM2Model(config, linear_method)
self.quant_config = quant_config
self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
def __init__(
self,
config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
self.head_dim,
total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
tp_rank = get_tensor_model_parallel_rank()
......@@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
self,
intermediate_size: int,
config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_fc2 = (ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
) if self.swiglu else None)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.act = SwiGLUActivation()
......@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
def __init__(
self,
config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, linear_method)
self.attn = JAISAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, linear_method)
self.mlp = JAISMLP(inner_dim, config, quant_config)
def forward(
self,
......@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
def __init__(
self,
config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -227,7 +228,7 @@ class JAISModel(nn.Module):
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, linear_method)
JAISBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
def __init__(
self,
config: JAISConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = JAISModel(config, linear_method)
self.quant_config = quant_config
self.transformer = JAISModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
......
......@@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QKVParallelLinear] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
......@@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -199,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
)
......@@ -207,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -248,7 +249,7 @@ class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
......@@ -264,7 +265,7 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method)
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
......@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self,
config: "LlavaConfig",
vision_language_config: VisionLanguageConfig,
linear_method: Optional["LinearMethodBase"] = None) -> None:
quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
self.config = config
......@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.linear_method = linear_method
self.language_model = LlamaModel(config.text_config, linear_method)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
......
......@@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
linear_method=None)
quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
......@@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0:
......@@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts,
......@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
......@@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, linear_method)
MiniCPMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.num_experts = getattr(self.config, "num_experts", 0)
self.linear_method = linear_method
self.quant_config = quant_config
self.model = MiniCPMModel(config,
linear_method,
quant_config,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
......
......@@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
......@@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
per_tensor_quantize)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
......@@ -79,7 +80,7 @@ class MixtralMoE(nn.Module):
self.intermediate_size = intermediate_size // self.tp_size
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -89,7 +90,7 @@ class MixtralMoE(nn.Module):
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
linear_method=None)
quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
......@@ -140,10 +141,10 @@ class MixtralMoE(nn.Module):
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
self.ws.data[expert, :, :])
w2s[expert, :, :], self.w2s_scale[
expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
self.ws = nn.Parameter(ws, requires_grad=False)
self.w2s = nn.Parameter(w2s, requires_grad=False)
......@@ -178,7 +179,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -203,12 +204,12 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta
self.sliding_window = sliding_window
if isinstance(linear_method, Fp8LinearMethod):
if isinstance(quant_config, Fp8Config):
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
linear_method = None
quant_config = None
self.qkv_proj = QKVParallelLinear(
hidden_size,
......@@ -216,13 +217,13 @@ class MixtralAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -259,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -272,13 +273,13 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
linear_method=linear_method)
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
......@@ -318,7 +319,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
......@@ -334,7 +335,7 @@ class MixtralModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -384,14 +385,13 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config,
linear_method,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
......
......@@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
......@@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
......@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
quant_config=quant_config)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
linear_method=None)
quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
......@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.quant_config = quant_config
self.model = MixtralModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
......@@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
linear_method=linear_method,
quant_config=quant_config,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
......@@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
self.d_model,
self.d_model,
bias=not config.no_bias,
linear_method=linear_method,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
......@@ -143,15 +144,15 @@ class MPTMLP(nn.Module):
hidden_size,
intermediate_size,
bias=not config.no_bias,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -166,14 +167,14 @@ class MPTBlock(nn.Module):
def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, linear_method)
self.attn = MPTAttention(config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, linear_method)
self.ffn = MPTMLP(config, quant_config)
def forward(
self,
......@@ -201,7 +202,7 @@ class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
assert config.embedding_fraction == 1.0
......@@ -212,7 +213,7 @@ class MPTModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config, linear_method) for _ in range(config.n_layers)])
[MPTBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
......@@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module):
def __init__(
self,
config: MPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert config.tie_word_embeddings
self.linear_method = linear_method
self.quant_config = quant_config
self.transformer = MPTModel(config, linear_method)
self.transformer = MPTModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -30,11 +30,12 @@ from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -54,7 +55,7 @@ class OlmoAttention(nn.Module):
def __init__(
self,
config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -79,7 +80,7 @@ class OlmoAttention(nn.Module):
self.head_dim,
self.total_num_heads,
bias=config.attention_bias,
linear_method=linear_method,
quant_config=quant_config,
)
# Rotary embeddings.
......@@ -99,7 +100,7 @@ class OlmoAttention(nn.Module):
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(
......@@ -129,7 +130,7 @@ class OlmoMLP(nn.Module):
def __init__(
self,
config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -141,7 +142,7 @@ class OlmoMLP(nn.Module):
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
# Activation function.
......@@ -152,7 +153,7 @@ class OlmoMLP(nn.Module):
self.intermediate_size,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(
......@@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self,
config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, linear_method)
self.self_attn = OlmoAttention(config, quant_config)
# MLP block.
self.mlp = OlmoMLP(config, linear_method)
self.mlp = OlmoMLP(config, quant_config)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -216,14 +217,14 @@ class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
OlmoDecoderLayer(config, linear_method)
OlmoDecoderLayer(config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size,
......@@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module):
def __init__(self,
config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OlmoModel(config, linear_method)
self.model = OlmoModel(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
......
......@@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
embed_dim: int,
num_heads: int,
bias: bool = True,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_dim = embed_dim
......@@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
self.head_dim,
total_num_heads,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
......@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.do_layer_norm_before = config.do_layer_norm_before
......@@ -127,16 +128,16 @@ class OPTDecoderLayer(nn.Module):
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
......@@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -202,7 +203,7 @@ class OPTDecoder(nn.Module):
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
else:
self.project_out = None
......@@ -210,7 +211,7 @@ class OPTDecoder(nn.Module):
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
else:
self.project_in = None
......@@ -226,7 +227,7 @@ class OPTDecoder(nn.Module):
self.final_layer_norm = None
self.layers = nn.ModuleList([
OPTDecoderLayer(config, linear_method)
OPTDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -259,10 +260,10 @@ class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.decoder = OPTDecoder(config, linear_method)
self.decoder = OPTDecoder(config, quant_config)
def forward(
self,
......@@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OPTModel(config, linear_method)
self.quant_config = quant_config
self.model = OPTModel(config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -13,11 +13,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -212,7 +213,7 @@ class OrionModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
OrionDecoderLayer(config, linear_method)
OrionDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OrionModel(config, linear_method)
self.quant_config = quant_config
self.model = OrionModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
......@@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
self.head_size,
self.total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
scaling = self.head_size**-0.5
......@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
n_inner = getattr(config, "n_inner", None)
......@@ -134,14 +135,14 @@ class PhiMLP(nn.Module):
self.fc1 = ColumnParallelLinear(
config.hidden_size,
n_inner,
linear_method=linear_method,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states):
......@@ -155,12 +156,12 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, linear_method)
self.mlp = PhiMLP(config, linear_method)
self.self_attn = PhiAttention(config, quant_config)
self.mlp = PhiMLP(config, quant_config)
def forward(
self,
......@@ -186,14 +187,14 @@ class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, linear_method)
PhiLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -225,12 +226,12 @@ class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.model = PhiModel(config, linear_method)
self.model = PhiModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
......
......@@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.scaling = self.head_dim**-0.5
......@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
linear_method=linear_method)
quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size,
config.intermediate_size // 2,
linear_method=linear_method)
quant_config=quant_config)
def forward(
self,
......@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -185,7 +186,7 @@ class QWenModel(nn.Module):
config.hidden_size,
)
self.h = nn.ModuleList([
QWenBlock(config, linear_method)
QWenBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method)
self.quant_config = quant_config
self.transformer = QWenModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment