"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "168cab6bbfb733f97defc8c1aa13df90c5319f19"
Unverified commit 0fca3cdc, authored by Woosuk Kwon, committed by GitHub
Browse files

[Misc] Enhance attention selector (#4751)

parent e7c46b95
...@@ -27,6 +27,7 @@ from torch.nn import LayerNorm ...@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -77,6 +78,7 @@ class FalconAttention(nn.Module): ...@@ -77,6 +78,7 @@ class FalconAttention(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -168,7 +170,8 @@ class FalconAttention(nn.Module): ...@@ -168,7 +170,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.inv_norm_factor, scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module): ...@@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, quant_config) self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.mlp = FalconMLP(config, quant_config) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
...@@ -311,6 +316,7 @@ class FalconModel(nn.Module): ...@@ -311,6 +316,7 @@ class FalconModel(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -327,7 +333,7 @@ class FalconModel(nn.Module): ...@@ -327,7 +333,7 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
FalconDecoderLayer(config, quant_config) FalconDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module): ...@@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = FalconModel(config, quant_config) self.transformer = FalconModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from transformers import GemmaConfig from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
...@@ -107,6 +107,7 @@ class GemmaAttention(nn.Module): ...@@ -107,6 +107,7 @@ class GemmaAttention(nn.Module):
head_dim: int, head_dim: int,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -155,7 +156,8 @@ class GemmaAttention(nn.Module): ...@@ -155,7 +156,8 @@ class GemmaAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -188,6 +191,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -188,6 +191,7 @@ class GemmaDecoderLayer(nn.Module):
head_dim=config.head_dim, head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta, rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.mlp = GemmaMLP( self.mlp = GemmaMLP(
...@@ -236,6 +240,7 @@ class GemmaModel(nn.Module): ...@@ -236,6 +240,7 @@ class GemmaModel(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -246,7 +251,7 @@ class GemmaModel(nn.Module): ...@@ -246,7 +251,7 @@ class GemmaModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
GemmaDecoderLayer(config, quant_config) GemmaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module): ...@@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -316,7 +322,7 @@ class GemmaForCausalLM(nn.Module): ...@@ -316,7 +322,7 @@ class GemmaForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel(config, quant_config) self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -24,6 +24,7 @@ from torch import nn ...@@ -24,6 +24,7 @@ from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -45,6 +46,7 @@ class GPT2Attention(nn.Module): ...@@ -45,6 +46,7 @@ class GPT2Attention(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -70,7 +72,10 @@ class GPT2Attention(nn.Module): ...@@ -70,7 +72,10 @@ class GPT2Attention(nn.Module):
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
) )
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -122,6 +127,7 @@ class GPT2Block(nn.Module): ...@@ -122,6 +127,7 @@ class GPT2Block(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -130,7 +136,7 @@ class GPT2Block(nn.Module): ...@@ -130,7 +136,7 @@ class GPT2Block(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, quant_config) self.attn = GPT2Attention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config) self.mlp = GPT2MLP(inner_dim, config, quant_config)
...@@ -163,6 +169,7 @@ class GPT2Model(nn.Module): ...@@ -163,6 +169,7 @@ class GPT2Model(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -174,7 +181,7 @@ class GPT2Model(nn.Module): ...@@ -174,7 +181,7 @@ class GPT2Model(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.h = nn.ModuleList([
GPT2Block(config, quant_config) GPT2Block(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module): ...@@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, quant_config) self.transformer = GPT2Model(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -25,6 +25,7 @@ from torch import nn ...@@ -25,6 +25,7 @@ from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -85,7 +87,8 @@ class GPTBigCodeAttention(nn.Module): ...@@ -85,7 +87,8 @@ class GPTBigCodeAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module): ...@@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -151,7 +155,7 @@ class GPTBigCodeBlock(nn.Module): ...@@ -151,7 +155,7 @@ class GPTBigCodeBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, quant_config) self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config) self.mlp = GPTBigMLP(inner_dim, config, quant_config)
...@@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -195,7 +200,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -195,7 +200,7 @@ class GPTBigCodeModel(nn.Module):
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.h = nn.ModuleList([
GPTBigCodeBlock(config, quant_config) GPTBigCodeBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, quant_config) self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -23,6 +23,7 @@ from torch import nn ...@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTJConfig from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -45,6 +46,7 @@ class GPTJAttention(nn.Module): ...@@ -45,6 +46,7 @@ class GPTJAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -83,7 +85,10 @@ class GPTJAttention(nn.Module): ...@@ -83,7 +85,10 @@ class GPTJAttention(nn.Module):
base=rope_theta, base=rope_theta,
is_neox_style=False, is_neox_style=False,
) )
self.attn = Attention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -135,13 +140,14 @@ class GPTJBlock(nn.Module): ...@@ -135,13 +140,14 @@ class GPTJBlock(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
inner_dim = (4 * config.n_embd inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner) if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, quant_config) self.attn = GPTJAttention(config, cache_config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward( def forward(
...@@ -169,6 +175,7 @@ class GPTJModel(nn.Module): ...@@ -169,6 +175,7 @@ class GPTJModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -178,8 +185,10 @@ class GPTJModel(nn.Module): ...@@ -178,8 +185,10 @@ class GPTJModel(nn.Module):
config.vocab_size, config.vocab_size,
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList( self.h = nn.ModuleList([
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) GPTJBlock(config, cache_config, quant_config)
for _ in range(config.n_layer)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module): ...@@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, quant_config) self.transformer = GPTJModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.n_embd, config.n_embd,
......
...@@ -23,6 +23,7 @@ from torch import nn ...@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTNeoXConfig from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module): ...@@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -84,7 +86,10 @@ class GPTNeoXAttention(nn.Module): ...@@ -84,7 +86,10 @@ class GPTNeoXAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
) )
self.attn = Attention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module): ...@@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -142,7 +148,7 @@ class GPTNeoXLayer(nn.Module): ...@@ -142,7 +148,7 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, quant_config) self.attention = GPTNeoXAttention(config, cache_config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config)
def forward( def forward(
...@@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module): ...@@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -192,7 +199,7 @@ class GPTNeoXModel(nn.Module): ...@@ -192,7 +199,7 @@ class GPTNeoXModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
GPTNeoXLayer(config, quant_config) GPTNeoXLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
...@@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, quant_config) self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
self.embed_out = ParallelLMHead( self.embed_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -6,6 +6,7 @@ from torch import nn ...@@ -6,6 +6,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -64,6 +65,7 @@ class InternLM2Attention(nn.Module): ...@@ -64,6 +65,7 @@ class InternLM2Attention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -114,7 +116,8 @@ class InternLM2Attention(nn.Module): ...@@ -114,7 +116,8 @@ class InternLM2Attention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module): ...@@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -151,6 +155,7 @@ class InternLMDecoderLayer(nn.Module): ...@@ -151,6 +155,7 @@ class InternLMDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.feed_forward = InternLM2MLP( self.feed_forward = InternLM2MLP(
...@@ -196,6 +201,7 @@ class InternLM2Model(nn.Module): ...@@ -196,6 +201,7 @@ class InternLM2Model(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -207,7 +213,7 @@ class InternLM2Model(nn.Module): ...@@ -207,7 +213,7 @@ class InternLM2Model(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
InternLMDecoderLayer(config, quant_config) InternLMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(config, quant_config) self.model = InternLM2Model(config, cache_config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -26,6 +26,7 @@ import torch ...@@ -26,6 +26,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -69,6 +70,7 @@ class JAISAttention(nn.Module): ...@@ -69,6 +70,7 @@ class JAISAttention(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -108,6 +110,7 @@ class JAISAttention(nn.Module): ...@@ -108,6 +110,7 @@ class JAISAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -170,6 +173,7 @@ class JAISBlock(nn.Module): ...@@ -170,6 +173,7 @@ class JAISBlock(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -178,7 +182,7 @@ class JAISBlock(nn.Module): ...@@ -178,7 +182,7 @@ class JAISBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, quant_config) self.attn = JAISAttention(config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config) self.mlp = JAISMLP(inner_dim, config, quant_config)
...@@ -211,6 +215,7 @@ class JAISModel(nn.Module): ...@@ -211,6 +215,7 @@ class JAISModel(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -228,7 +233,7 @@ class JAISModel(nn.Module): ...@@ -228,7 +233,7 @@ class JAISModel(nn.Module):
else: else:
self.embeddings_scale = config.mup_embeddings_scale self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([ self.h = nn.ModuleList([
JAISBlock(config, quant_config) JAISBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
...@@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module): ...@@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
......
...@@ -28,7 +28,7 @@ from torch import nn ...@@ -28,7 +28,7 @@ from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -94,6 +94,7 @@ class LlamaAttention(nn.Module): ...@@ -94,6 +94,7 @@ class LlamaAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -153,7 +154,8 @@ class LlamaAttention(nn.Module): ...@@ -153,7 +154,8 @@ class LlamaAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window) sliding_window=sliding_window,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -204,6 +207,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -204,6 +207,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=attention_bias, bias=attention_bias,
sliding_window=sliding_window, sliding_window=sliding_window,
cache_config=cache_config,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -251,6 +255,7 @@ class LlamaModel(nn.Module): ...@@ -251,6 +255,7 @@ class LlamaModel(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -267,7 +272,7 @@ class LlamaModel(nn.Module): ...@@ -267,7 +272,7 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config) LlamaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module): ...@@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = LlamaModel(config, quant_config, lora_config=lora_config) self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -7,7 +7,7 @@ from torch import nn ...@@ -7,7 +7,7 @@ from torch import nn
from transformers import CLIPVisionModel, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): ...@@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self, def __init__(self,
config: "LlavaConfig", config: "LlavaConfig",
vision_language_config: VisionLanguageConfig, vision_language_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None) -> None: quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -85,7 +86,8 @@ class LlavaForConditionalGeneration(nn.Module): ...@@ -85,7 +86,8 @@ class LlavaForConditionalGeneration(nn.Module):
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, quant_config) self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
......
...@@ -28,7 +28,7 @@ import torch ...@@ -28,7 +28,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -181,6 +181,7 @@ class MiniCPMAttention(nn.Module): ...@@ -181,6 +181,7 @@ class MiniCPMAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -234,7 +235,8 @@ class MiniCPMAttention(nn.Module): ...@@ -234,7 +235,8 @@ class MiniCPMAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -275,6 +278,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -275,6 +278,7 @@ class MiniCPMDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
...@@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module): ...@@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -346,7 +351,7 @@ class MiniCPMModel(nn.Module): ...@@ -346,7 +351,7 @@ class MiniCPMModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, quant_config) MiniCPMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module): ...@@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -421,6 +427,7 @@ class MiniCPMForCausalLM(nn.Module): ...@@ -421,6 +427,7 @@ class MiniCPMForCausalLM(nn.Module):
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config self.quant_config = quant_config
self.model = MiniCPMModel(config, self.model = MiniCPMModel(config,
cache_config,
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config)
unpadded_vocab_size = config.vocab_size unpadded_vocab_size = config.vocab_size
......
...@@ -29,7 +29,7 @@ from transformers import MixtralConfig ...@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -252,6 +252,7 @@ class MixtralAttention(nn.Module): ...@@ -252,6 +252,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
...@@ -313,6 +314,7 @@ class MixtralAttention(nn.Module): ...@@ -313,6 +314,7 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -348,6 +351,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -348,6 +351,7 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.block_sparse_moe = MixtralMoE( self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
...@@ -394,6 +398,7 @@ class MixtralModel(nn.Module): ...@@ -394,6 +398,7 @@ class MixtralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -410,7 +415,9 @@ class MixtralModel(nn.Module): ...@@ -410,7 +415,9 @@ class MixtralModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config) MixtralDecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module): ...@@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = MixtralModel(config, self.model = MixtralModel(config,
cache_config,
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
......
...@@ -30,6 +30,7 @@ from torch import nn ...@@ -30,6 +30,7 @@ from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -157,14 +158,17 @@ class MixtralMoE(nn.Module): ...@@ -157,14 +158,17 @@ class MixtralMoE(nn.Module):
class MixtralAttention(nn.Module): class MixtralAttention(nn.Module):
def __init__(self, def __init__(
hidden_size: int, self,
num_heads: int, hidden_size: int,
num_kv_heads: int, num_heads: int,
max_position: int = 4096 * 32, num_kv_heads: int,
rope_theta: float = 10000, max_position: int = 4096 * 32,
quant_config: Optional[QuantizationConfig] = None, rope_theta: float = 10000,
sliding_window: Optional[int] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -215,6 +219,7 @@ class MixtralAttention(nn.Module): ...@@ -215,6 +219,7 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config,
) )
def forward( def forward(
...@@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -250,6 +256,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -250,6 +256,7 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config, self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config) quant_config=quant_config)
...@@ -292,6 +299,7 @@ class MixtralModel(nn.Module): ...@@ -292,6 +299,7 @@ class MixtralModel(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -303,7 +311,9 @@ class MixtralModel(nn.Module): ...@@ -303,7 +311,9 @@ class MixtralModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MixtralDecoderLayer(config, quant_config=quant_config) MixtralDecoderLayer(config,
cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module): ...@@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, quant_config) self.model = MixtralModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -43,6 +44,7 @@ class MPTAttention(nn.Module): ...@@ -43,6 +44,7 @@ class MPTAttention(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -107,7 +109,8 @@ class MPTAttention(nn.Module): ...@@ -107,7 +109,8 @@ class MPTAttention(nn.Module):
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -166,12 +169,13 @@ class MPTBlock(nn.Module): ...@@ -166,12 +169,13 @@ class MPTBlock(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size) self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, quant_config) self.attn = MPTAttention(config, cache_config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size) self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config) self.ffn = MPTMLP(config, quant_config)
...@@ -201,6 +205,7 @@ class MPTModel(nn.Module): ...@@ -201,6 +205,7 @@ class MPTModel(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -211,8 +216,10 @@ class MPTModel(nn.Module): ...@@ -211,8 +216,10 @@ class MPTModel(nn.Module):
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList([
[MPTBlock(config, quant_config) for _ in range(config.n_layers)]) MPTBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:
for module in self.modules(): for module in self.modules():
...@@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module): ...@@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -253,7 +261,7 @@ class MPTForCausalLM(nn.Module): ...@@ -253,7 +261,7 @@ class MPTForCausalLM(nn.Module):
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = MPTModel(config, quant_config) self.transformer = MPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -28,6 +28,7 @@ from torch import nn ...@@ -28,6 +28,7 @@ from torch import nn
from transformers import OlmoConfig from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -55,6 +56,7 @@ class OlmoAttention(nn.Module): ...@@ -55,6 +56,7 @@ class OlmoAttention(nn.Module):
def __init__( def __init__(
self, self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -93,7 +95,8 @@ class OlmoAttention(nn.Module): ...@@ -93,7 +95,8 @@ class OlmoAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling,
cache_config=cache_config)
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module): ...@@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = OlmoAttention(config, quant_config) self.self_attn = OlmoAttention(config, cache_config, quant_config)
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, quant_config) self.mlp = OlmoMLP(config, quant_config)
...@@ -217,6 +221,7 @@ class OlmoModel(nn.Module): ...@@ -217,6 +221,7 @@ class OlmoModel(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -224,7 +229,7 @@ class OlmoModel(nn.Module): ...@@ -224,7 +229,7 @@ class OlmoModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
OlmoDecoderLayer(config, quant_config) OlmoDecoderLayer(config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
...@@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module): ...@@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = OlmoModel(config, quant_config) self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head_weight = self.model.embed_tokens.weight
else: else:
......
...@@ -24,6 +24,7 @@ from torch import nn ...@@ -24,6 +24,7 @@ from torch import nn
from transformers import OPTConfig from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -61,6 +62,7 @@ class OPTAttention(nn.Module): ...@@ -61,6 +62,7 @@ class OPTAttention(nn.Module):
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
bias: bool = True, bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -88,7 +90,8 @@ class OPTAttention(nn.Module): ...@@ -88,7 +90,8 @@ class OPTAttention(nn.Module):
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -117,6 +121,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -117,6 +121,7 @@ class OPTDecoderLayer(nn.Module):
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
bias=config.enable_bias, bias=config.enable_bias,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.do_layer_norm_before = config.do_layer_norm_before self.do_layer_norm_before = config.do_layer_norm_before
...@@ -181,6 +186,7 @@ class OPTDecoder(nn.Module): ...@@ -181,6 +186,7 @@ class OPTDecoder(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -226,7 +232,7 @@ class OPTDecoder(nn.Module): ...@@ -226,7 +232,7 @@ class OPTDecoder(nn.Module):
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
OPTDecoderLayer(config, quant_config) OPTDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -259,10 +265,11 @@ class OPTModel(nn.Module): ...@@ -259,10 +265,11 @@ class OPTModel(nn.Module):
def __init__( def __init__(
self, self,
config: OPTConfig, config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.decoder = OPTDecoder(config, quant_config) self.decoder = OPTDecoder(config, cache_config, quant_config)
def forward( def forward(
self, self,
...@@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module): ...@@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OPTModel(config, quant_config) self.model = OPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -11,6 +11,7 @@ from torch import nn ...@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -68,6 +69,7 @@ class OrionAttention(nn.Module): ...@@ -68,6 +69,7 @@ class OrionAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -118,7 +120,8 @@ class OrionAttention(nn.Module): ...@@ -118,7 +120,8 @@ class OrionAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -155,6 +159,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -155,6 +159,7 @@ class OrionDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
) )
self.mlp = OrionMLP( self.mlp = OrionMLP(
...@@ -202,6 +207,7 @@ class OrionModel(nn.Module): ...@@ -202,6 +207,7 @@ class OrionModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -213,7 +219,7 @@ class OrionModel(nn.Module): ...@@ -213,7 +219,7 @@ class OrionModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
OrionDecoderLayer(config, quant_config) OrionDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module): ...@@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OrionModel(config, quant_config) self.model = OrionModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -42,6 +42,7 @@ from torch import nn ...@@ -42,6 +42,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -63,6 +64,7 @@ class PhiAttention(nn.Module): ...@@ -63,6 +64,7 @@ class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
...@@ -105,7 +107,10 @@ class PhiAttention(nn.Module): ...@@ -105,7 +107,10 @@ class PhiAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
) )
self.attn = Attention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -155,11 +160,12 @@ class PhiLayer(nn.Module): ...@@ -155,11 +160,12 @@ class PhiLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, quant_config) self.self_attn = PhiAttention(config, cache_config, quant_config)
self.mlp = PhiMLP(config, quant_config) self.mlp = PhiMLP(config, quant_config)
def forward( def forward(
...@@ -186,6 +192,7 @@ class PhiModel(nn.Module): ...@@ -186,6 +192,7 @@ class PhiModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -193,7 +200,7 @@ class PhiModel(nn.Module): ...@@ -193,7 +200,7 @@ class PhiModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
PhiLayer(config, quant_config) PhiLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module): ...@@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = PhiModel(config, quant_config) self.model = PhiModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -11,6 +11,7 @@ from torch import nn ...@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -68,6 +69,7 @@ class QWenAttention(nn.Module): ...@@ -68,6 +69,7 @@ class QWenAttention(nn.Module):
max_position_embeddings: int, max_position_embeddings: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -101,7 +103,10 @@ class QWenAttention(nn.Module): ...@@ -101,7 +103,10 @@ class QWenAttention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = Attention(self.num_heads, self.head_dim, self.scaling) self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -123,6 +128,7 @@ class QWenBlock(nn.Module): ...@@ -123,6 +128,7 @@ class QWenBlock(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -135,6 +141,7 @@ class QWenBlock(nn.Module): ...@@ -135,6 +141,7 @@ class QWenBlock(nn.Module):
config.max_position_embeddings, config.max_position_embeddings,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -175,6 +182,7 @@ class QWenModel(nn.Module): ...@@ -175,6 +182,7 @@ class QWenModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -186,7 +194,7 @@ class QWenModel(nn.Module): ...@@ -186,7 +194,7 @@ class QWenModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.h = nn.ModuleList([ self.h = nn.ModuleList([
QWenBlock(config, quant_config) QWenBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module): ...@@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = QWenModel(config, quant_config) self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -87,6 +87,7 @@ class Qwen2Attention(nn.Module): ...@@ -87,6 +87,7 @@ class Qwen2Attention(nn.Module):
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
use_sliding_window: bool = False, use_sliding_window: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
...@@ -137,7 +138,8 @@ class Qwen2Attention(nn.Module): ...@@ -137,7 +138,8 @@ class Qwen2Attention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window) sliding_window=self.sliding_window,
cache_config=cache_config)
def forward( def forward(
self, self,
...@@ -160,6 +162,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -160,6 +162,7 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
config: Qwen2Config, config: Qwen2Config,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -175,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -175,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
use_sliding_window=use_sliding_window, use_sliding_window=use_sliding_window,
cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
sliding_window=config.sliding_window) sliding_window=config.sliding_window)
self.mlp = Qwen2MLP( self.mlp = Qwen2MLP(
...@@ -222,6 +226,7 @@ class Qwen2Model(nn.Module): ...@@ -222,6 +226,7 @@ class Qwen2Model(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -234,7 +239,7 @@ class Qwen2Model(nn.Module): ...@@ -234,7 +239,7 @@ class Qwen2Model(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, quant_config) Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -294,7 +300,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -294,7 +300,7 @@ class Qwen2ForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config) self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head_weight = self.model.embed_tokens.weight
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment