Unverified commit eebad39f, authored by youkaichao, committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent db100c5c
......@@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
......@@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config)
self.self_attention = GLMAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
......@@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config),
lambda prefix: GLMBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
......@@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.encoder = GLMTransformer(config,
cache_config,
quant_config,
prefix=f"{prefix}.encoder")
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.output_layer")
vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config)
self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else:
self.vision = None
......
......@@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
......@@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
......@@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
......@@ -271,8 +275,8 @@ class CohereModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config,
quant_config),
lambda prefix: CohereDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
......
......@@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
......@@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, cache_config, quant_config)
self.attn = DbrxAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
......@@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(
config,
cache_config,
quant_config,
prefix=f"{prefix}.norm_attn_norm")
self.ffn = DbrxMoE(config, quant_config)
def forward(
......@@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config),
lambda prefix: DbrxBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
......
......@@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
......
......@@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
......@@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
quant_config=quant_config,
bias=bias,
cache_config=cache_config,
prefix=prefix,
prefix=f"{prefix}.attention",
)
def forward(
......
......@@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
......@@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
......@@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config,
quant_config)
self.self_attention = FalconAttention(
config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.mlp = FalconMLP(config, quant_config)
self.config = config
......@@ -357,8 +365,8 @@ class FalconModel(nn.Module):
# Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
lambda prefix: FalconDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
# Final Layer Norm
......
......@@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
self.encoder = BartEncoder(config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.decoder")
if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight
......@@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
self.config = config
self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix)
prefix=f"{prefix}.model")
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
......@@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
# TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix,
prefix=f"{prefix}.language_model",
)
@property
......
......@@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None:
attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
......@@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap)
logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping,
prefix=f"{prefix}.self_attn",
)
self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP(
......@@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
config,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -56,6 +56,7 @@ class Attention(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = Attention(config, quant_config=quant_config)
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config)
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
......@@ -161,11 +165,14 @@ class Transformer(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(self, hidden_states):
......@@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
vision_config = Namespace(**config.vision_config)
self.patch_embedding = PatchEmbedding(vision_config)
self.transformer = Transformer(vision_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)
......
......@@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
self.head_dim,
scale=self.scale,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
......@@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
self.attn = GPTBigCodeAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
......@@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
lambda prefix: GPTBigCodeBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......
......@@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.total_num_heads = config.num_attention_heads
......@@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, cache_config, quant_config)
self.attn = GPTJAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
......@@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
)
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config),
lambda prefix: GPTJBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......
......@@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.total_num_heads = config.num_attention_heads
......@@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
......@@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config)
self.attention = GPTNeoXAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
......@@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
lambda prefix: GPTNeoXLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
......
......@@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
......@@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer(
lambda prefix: layer_type(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -316,14 +321,18 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = model_type(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
......
......@@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module):
......@@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=InternLM2VEDecoderLayer)
def forward(
self,
......@@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
super().__init__(vllm_config=vllm_config,
prefix=prefix,
model_type=InternLM2VEModel)
......@@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
......@@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, cache_config, quant_config)
self.attn = JAISAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
......@@ -241,7 +247,8 @@ class JAISModel(nn.Module):
config.num_hidden_layers,
lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.h",
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment