Unverified Commit 0fca3cdc authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
......@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -77,6 +78,7 @@ class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -168,7 +170,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, quant_config)
self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.mlp = FalconMLP(config, quant_config)
self.config = config
......@@ -311,6 +316,7 @@ class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -327,7 +333,7 @@ class FalconModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, quant_config)
FalconDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module):
def __init__(
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = FalconModel(config, quant_config)
self.transformer = FalconModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -22,7 +22,7 @@ from torch import nn
from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
......@@ -107,6 +107,7 @@ class GemmaAttention(nn.Module):
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -155,7 +156,8 @@ class GemmaAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -188,6 +191,7 @@ class GemmaDecoderLayer(nn.Module):
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = GemmaMLP(
......@@ -236,6 +240,7 @@ class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -246,7 +251,7 @@ class GemmaModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, quant_config)
GemmaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -316,7 +322,7 @@ class GemmaForCausalLM(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = GemmaModel(config, quant_config)
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -24,6 +24,7 @@ from torch import nn
from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -45,6 +46,7 @@ class GPT2Attention(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -70,7 +72,10 @@ class GPT2Attention(nn.Module):
bias=True,
quant_config=quant_config,
)
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config)
def forward(
self,
......@@ -122,6 +127,7 @@ class GPT2Block(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -130,7 +136,7 @@ class GPT2Block(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, quant_config)
self.attn = GPT2Attention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config)
......@@ -163,6 +169,7 @@ class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -174,7 +181,7 @@ class GPT2Model(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPT2Block(config, quant_config)
GPT2Block(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config, quant_config)
self.transformer = GPT2Model(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -25,6 +25,7 @@ from torch import nn
from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -85,7 +87,8 @@ class GPTBigCodeAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -151,7 +155,7 @@ class GPTBigCodeBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, quant_config)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
......@@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -195,7 +200,7 @@ class GPTBigCodeModel(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPTBigCodeBlock(config, quant_config)
GPTBigCodeBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, quant_config)
self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -45,6 +46,7 @@ class GPTJAttention(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -83,7 +85,10 @@ class GPTJAttention(nn.Module):
base=rope_theta,
is_neox_style=False,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward(
self,
......@@ -135,13 +140,14 @@ class GPTJBlock(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, quant_config)
self.attn = GPTJAttention(config, cache_config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
......@@ -169,6 +175,7 @@ class GPTJModel(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -178,8 +185,10 @@ class GPTJModel(nn.Module):
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
self.h = nn.ModuleList([
GPTJBlock(config, cache_config, quant_config)
for _ in range(config.n_layer)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, quant_config)
self.transformer = GPTJModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -84,7 +86,10 @@ class GPTNeoXAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward(
self,
......@@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -142,7 +148,7 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, quant_config)
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
......@@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -192,7 +199,7 @@ class GPTNeoXModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
GPTNeoXLayer(config, quant_config)
GPTNeoXLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
......@@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, quant_config)
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
......
......@@ -6,6 +6,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -64,6 +65,7 @@ class InternLM2Attention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -114,7 +116,8 @@ class InternLM2Attention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -151,6 +155,7 @@ class InternLMDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.feed_forward = InternLM2MLP(
......@@ -196,6 +201,7 @@ class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -207,7 +213,7 @@ class InternLM2Model(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config, quant_config)
InternLMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config, quant_config)
self.model = InternLM2Model(config, cache_config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -26,6 +26,7 @@ import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -69,6 +70,7 @@ class JAISAttention(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -108,6 +110,7 @@ class JAISAttention(nn.Module):
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
)
def forward(
......@@ -170,6 +173,7 @@ class JAISBlock(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -178,7 +182,7 @@ class JAISBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, quant_config)
self.attn = JAISAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
......@@ -211,6 +215,7 @@ class JAISModel(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -228,7 +233,7 @@ class JAISModel(nn.Module):
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, quant_config)
JAISBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, quant_config)
self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
......
......@@ -28,7 +28,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -94,6 +94,7 @@ class LlamaAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -153,7 +154,8 @@ class LlamaAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
sliding_window=sliding_window,
cache_config=cache_config)
def forward(
self,
......@@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -204,6 +207,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
......@@ -251,6 +255,7 @@ class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -267,7 +272,7 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config)
LlamaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
......@@ -7,7 +7,7 @@ from torch import nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
......@@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self,
config: "LlavaConfig",
vision_language_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
self.config = config
......@@ -85,7 +86,8 @@ class LlavaForConditionalGeneration(nn.Module):
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, quant_config)
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
......
......@@ -28,7 +28,7 @@ import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -181,6 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -234,7 +235,8 @@ class MiniCPMAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -275,6 +278,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.num_experts = getattr(self.config, "num_experts", 0)
......@@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -346,7 +351,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, quant_config)
MiniCPMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -421,6 +427,7 @@ class MiniCPMForCausalLM(nn.Module):
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPMModel(config,
cache_config,
quant_config,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
......
......@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -252,6 +252,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
......@@ -313,6 +314,7 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
def forward(
......@@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -348,6 +351,7 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
......@@ -394,6 +398,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -410,7 +415,9 @@ class MixtralModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config)
MixtralDecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = MixtralModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
......
......@@ -30,6 +30,7 @@ from torch import nn
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -157,14 +158,17 @@ class MixtralMoE(nn.Module):
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
......@@ -215,6 +219,7 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
def forward(
......@@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -250,6 +256,7 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config)
......@@ -292,6 +299,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -303,7 +311,9 @@ class MixtralModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config)
MixtralDecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, quant_config)
self.model = MixtralModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -43,6 +44,7 @@ class MPTAttention(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -107,7 +109,8 @@ class MPTAttention(nn.Module):
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -166,12 +169,13 @@ class MPTBlock(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, quant_config)
self.attn = MPTAttention(config, cache_config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
......@@ -201,6 +205,7 @@ class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -211,8 +216,10 @@ class MPTModel(nn.Module):
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config, quant_config) for _ in range(config.n_layers)])
self.blocks = nn.ModuleList([
MPTBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
......@@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -253,7 +261,7 @@ class MPTForCausalLM(nn.Module):
assert config.tie_word_embeddings
self.quant_config = quant_config
self.transformer = MPTModel(config, quant_config)
self.transformer = MPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -28,6 +28,7 @@ from torch import nn
from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -55,6 +56,7 @@ class OlmoAttention(nn.Module):
def __init__(
self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -93,7 +95,8 @@ class OlmoAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
scale=self.scaling,
cache_config=cache_config)
# Attention output projection.
self.o_proj = RowParallelLinear(
......@@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, quant_config)
self.self_attn = OlmoAttention(config, cache_config, quant_config)
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
......@@ -217,6 +221,7 @@ class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -224,7 +229,7 @@ class OlmoModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
OlmoDecoderLayer(config, quant_config)
OlmoDecoderLayer(config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size,
......@@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = OlmoModel(config, quant_config)
self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
......
......@@ -24,6 +24,7 @@ from torch import nn
from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -61,6 +62,7 @@ class OPTAttention(nn.Module):
embed_dim: int,
num_heads: int,
bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -88,7 +90,8 @@ class OPTAttention(nn.Module):
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
scale=self.scaling,
cache_config=cache_config)
def forward(
self,
......@@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -117,6 +121,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
cache_config=cache_config,
quant_config=quant_config,
)
self.do_layer_norm_before = config.do_layer_norm_before
......@@ -181,6 +186,7 @@ class OPTDecoder(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -226,7 +232,7 @@ class OPTDecoder(nn.Module):
self.final_layer_norm = None
self.layers = nn.ModuleList([
OPTDecoderLayer(config, quant_config)
OPTDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -259,10 +265,11 @@ class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.decoder = OPTDecoder(config, quant_config)
self.decoder = OPTDecoder(config, cache_config, quant_config)
def forward(
self,
......@@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module):
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, quant_config)
self.model = OPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -68,6 +69,7 @@ class OrionAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -118,7 +120,8 @@ class OrionAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward(
self,
......@@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -155,6 +159,7 @@ class OrionDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
self.mlp = OrionMLP(
......@@ -202,6 +207,7 @@ class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -213,7 +219,7 @@ class OrionModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
OrionDecoderLayer(config, quant_config)
OrionDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, quant_config)
self.model = OrionModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -42,6 +42,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -63,6 +64,7 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
......@@ -105,7 +107,10 @@ class PhiAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = Attention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward(
self,
......@@ -155,11 +160,12 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, quant_config)
self.self_attn = PhiAttention(config, cache_config, quant_config)
self.mlp = PhiMLP(config, quant_config)
def forward(
......@@ -186,6 +192,7 @@ class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -193,7 +200,7 @@ class PhiModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, quant_config)
PhiLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = PhiModel(config, quant_config)
self.model = PhiModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
......
......@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -68,6 +69,7 @@ class QWenAttention(nn.Module):
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -101,7 +103,10 @@ class QWenAttention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
def forward(
self,
......@@ -123,6 +128,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -135,6 +141,7 @@ class QWenBlock(nn.Module):
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -175,6 +182,7 @@ class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -186,7 +194,7 @@ class QWenModel(nn.Module):
config.hidden_size,
)
self.h = nn.ModuleList([
QWenBlock(config, quant_config)
QWenBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = QWenModel(config, quant_config)
self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -29,7 +29,7 @@ from torch import nn
from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -87,6 +87,7 @@ class Qwen2Attention(nn.Module):
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
......@@ -137,7 +138,8 @@ class Qwen2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
sliding_window=self.sliding_window,
cache_config=cache_config)
def forward(
self,
......@@ -160,6 +162,7 @@ class Qwen2DecoderLayer(nn.Module):
self,
config: Qwen2Config,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -175,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
cache_config=cache_config,
quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = Qwen2MLP(
......@@ -222,6 +226,7 @@ class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -234,7 +239,7 @@ class Qwen2Model(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, quant_config)
Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
......@@ -294,7 +300,7 @@ class Qwen2ForCausalLM(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config)
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment