Commit eebad39f (unverified), authored by youkaichao, committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent db100c5c
...@@ -230,6 +230,7 @@ class GLMAttention(nn.Module): ...@@ -230,6 +230,7 @@ class GLMAttention(nn.Module):
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -285,7 +286,8 @@ class GLMAttention(nn.Module): ...@@ -285,7 +286,8 @@ class GLMAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -364,6 +366,7 @@ class GLMBlock(nn.Module): ...@@ -364,6 +366,7 @@ class GLMBlock(nn.Module):
config: ChatGLMConfig, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
...@@ -377,7 +380,10 @@ class GLMBlock(nn.Module): ...@@ -377,7 +380,10 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, cache_config, quant_config) self.self_attention = GLMAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -446,7 +452,8 @@ class GLMTransformer(nn.Module): ...@@ -446,7 +452,8 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers, self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config), lambda prefix: GLMBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
...@@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module): ...@@ -500,16 +507,22 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config) self.encoder = GLMTransformer(config,
cache_config,
quant_config,
prefix=f"{prefix}.encoder")
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.output_layer")
vision_config_flag = getattr(config, 'vision_config', None) vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None: if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config) self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config, quant_config) self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else: else:
self.vision = None self.vision = None
......
...@@ -120,6 +120,7 @@ class CohereAttention(nn.Module): ...@@ -120,6 +120,7 @@ class CohereAttention(nn.Module):
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -175,7 +176,8 @@ class CohereAttention(nn.Module): ...@@ -175,7 +176,8 @@ class CohereAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads, self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim), self.head_dim),
...@@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module): ...@@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: CohereConfig, config: CohereConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, self.self_attn = CohereAttention(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = CohereMLP(config, quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
...@@ -271,8 +275,8 @@ class CohereModel(nn.Module): ...@@ -271,8 +275,8 @@ class CohereModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config, lambda prefix: CohereDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
...@@ -154,6 +154,7 @@ class DbrxAttention(nn.Module): ...@@ -154,6 +154,7 @@ class DbrxAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -208,7 +209,8 @@ class DbrxAttention(nn.Module): ...@@ -208,7 +209,8 @@ class DbrxAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -234,10 +236,14 @@ class DbrxFusedNormAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.attn = DbrxAttention(config, cache_config, quant_config) self.attn = DbrxAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_1 = nn.LayerNorm(self.d_model) self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model)
...@@ -269,10 +275,14 @@ class DbrxBlock(nn.Module): ...@@ -269,10 +275,14 @@ class DbrxBlock(nn.Module):
config: DbrxConfig, config: DbrxConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, self.norm_attn_norm = DbrxFusedNormAttention(
quant_config) config,
cache_config,
quant_config,
prefix=f"{prefix}.norm_attn_norm")
self.ffn = DbrxMoE(config, quant_config) self.ffn = DbrxMoE(config, quant_config)
def forward( def forward(
...@@ -308,7 +318,8 @@ class DbrxModel(nn.Module): ...@@ -308,7 +318,8 @@ class DbrxModel(nn.Module):
) )
self.start_layer, self.end_layer, self.blocks = make_layers( self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers, config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config), lambda prefix: DbrxBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks", prefix=f"{prefix}.blocks",
) )
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
......
...@@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module): ...@@ -184,6 +184,7 @@ class DeepseekAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module): ...@@ -236,7 +237,8 @@ class DeepseekAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -261,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -277,6 +280,7 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
...@@ -346,7 +350,8 @@ class DeepseekModel(nn.Module): ...@@ -346,7 +350,8 @@ class DeepseekModel(nn.Module):
lambda prefix: DeepseekDecoderLayer(config, lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]), int(prefix.split(".")[-1]),
cache_config, cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module): ...@@ -268,7 +268,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module): ...@@ -174,6 +174,7 @@ class ExaoneAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
...@@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module): ...@@ -219,7 +220,7 @@ class ExaoneBlockAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=bias, bias=bias,
cache_config=cache_config, cache_config=cache_config,
prefix=prefix, prefix=f"{prefix}.attention",
) )
def forward( def forward(
......
...@@ -84,6 +84,7 @@ class FalconAttention(nn.Module): ...@@ -84,6 +84,7 @@ class FalconAttention(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -158,7 +159,8 @@ class FalconAttention(nn.Module): ...@@ -158,7 +159,8 @@ class FalconAttention(nn.Module):
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
elif self.use_alibi: elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads head_start = tp_rank * self.num_heads
...@@ -171,14 +173,16 @@ class FalconAttention(nn.Module): ...@@ -171,14 +173,16 @@ class FalconAttention(nn.Module):
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
else: else:
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.inv_norm_factor, scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module): ...@@ -241,12 +245,16 @@ class FalconDecoderLayer(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, cache_config, self.self_attention = FalconAttention(
quant_config) config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.mlp = FalconMLP(config, quant_config) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
...@@ -357,8 +365,8 @@ class FalconModel(nn.Module): ...@@ -357,8 +365,8 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config, lambda prefix: FalconDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
......
...@@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module): ...@@ -35,10 +35,12 @@ class Florence2LanguageModel(nn.Module):
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
self.encoder = BartEncoder(config, self.encoder = BartEncoder(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config, self.decoder = BartDecoder(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.decoder")
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight self.encoder.embed_tokens.weight = self.shared.weight
...@@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -99,7 +101,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
self.config = config self.config = config
self.model = Florence2LanguageModel(vllm_config=vllm_config, self.model = Florence2LanguageModel(vllm_config=vllm_config,
prefix=prefix) prefix=f"{prefix}.model")
embed_scale = math.sqrt( embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0 config.d_model) if config.scale_embedding else 1.0
...@@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module): ...@@ -198,7 +200,7 @@ class Florence2ForConditionalGeneration(nn.Module):
# TODO(Isotr0py): Add vision backbone # TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration( self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config), vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix, prefix=f"{prefix}.language_model",
) )
@property @property
......
...@@ -174,7 +174,8 @@ class GemmaAttention(nn.Module): ...@@ -174,7 +174,8 @@ class GemmaAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module): ...@@ -95,7 +95,8 @@ class Gemma2Attention(nn.Module):
rope_theta: float, rope_theta: float,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
attn_logits_soft_cap: Optional[float] = None) -> None: attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.config = config self.config = config
...@@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module): ...@@ -154,7 +155,8 @@ class Gemma2Attention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap) logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -179,6 +181,7 @@ class Gemma2DecoderLayer(nn.Module):
config: Gemma2Config, config: Gemma2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -194,6 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping, attn_logits_soft_cap=config.attn_logit_softcapping,
prefix=f"{prefix}.self_attn",
) )
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP( self.mlp = Gemma2MLP(
...@@ -257,8 +261,11 @@ class Gemma2Model(nn.Module): ...@@ -257,8 +261,11 @@ class Gemma2Model(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
-1]), config, cache_config, quant_config), config,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -56,6 +56,7 @@ class Attention(nn.Module): ...@@ -56,6 +56,7 @@ class Attention(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -135,11 +136,14 @@ class TransformerLayer(nn.Module): ...@@ -135,11 +136,14 @@ class TransformerLayer(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size, self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = Attention(config, quant_config=quant_config) self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config) self.mlp = MLP(config, quant_config=quant_config)
self.post_attention_layernorm = LayerNorm(config.hidden_size, self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -161,11 +165,14 @@ class Transformer(nn.Module): ...@@ -161,11 +165,14 @@ class Transformer(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
TransformerLayer(config, quant_config=quant_config) TransformerLayer(config,
for _ in range(config.num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
]) ])
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module): ...@@ -252,12 +259,14 @@ class EVA2CLIPModel(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
vision_config = Namespace(**config.vision_config) vision_config = Namespace(**config.vision_config)
self.patch_embedding = PatchEmbedding(vision_config) self.patch_embedding = PatchEmbedding(vision_config)
self.transformer = Transformer(vision_config, self.transformer = Transformer(vision_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config, self.linear_proj = GLU(config,
in_features=config.hidden_size, in_features=config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -84,7 +84,8 @@ class GPT2Attention(nn.Module): ...@@ -84,7 +84,8 @@ class GPT2Attention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -52,6 +52,7 @@ class GPTBigCodeAttention(nn.Module):
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module): ...@@ -92,7 +93,8 @@ class GPTBigCodeAttention(nn.Module):
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module): ...@@ -151,6 +153,7 @@ class GPTBigCodeBlock(nn.Module):
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module): ...@@ -158,7 +161,10 @@ class GPTBigCodeBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.attn = GPTBigCodeAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config) self.mlp = GPTBigMLP(inner_dim, config, quant_config)
...@@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module): ...@@ -210,7 +216,8 @@ class GPTBigCodeModel(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), lambda prefix: GPTBigCodeBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......
...@@ -53,6 +53,7 @@ class GPTJAttention(nn.Module): ...@@ -53,6 +53,7 @@ class GPTJAttention(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
...@@ -94,7 +95,8 @@ class GPTJAttention(nn.Module): ...@@ -94,7 +95,8 @@ class GPTJAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -147,12 +149,16 @@ class GPTJBlock(nn.Module): ...@@ -147,12 +149,16 @@ class GPTJBlock(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
inner_dim = (4 * config.n_embd inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner) if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, cache_config, quant_config) self.attn = GPTJAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.mlp = GPTJMLP(inner_dim, config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward( def forward(
...@@ -193,7 +199,8 @@ class GPTJModel(nn.Module): ...@@ -193,7 +199,8 @@ class GPTJModel(nn.Module):
) )
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer, config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config), lambda prefix: GPTJBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......
...@@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module): ...@@ -52,6 +52,7 @@ class GPTNeoXAttention(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
...@@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module): ...@@ -94,7 +95,8 @@ class GPTNeoXAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module): ...@@ -145,6 +147,7 @@ class GPTNeoXLayer(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
...@@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module): ...@@ -152,7 +155,10 @@ class GPTNeoXLayer(nn.Module):
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.attention = GPTNeoXAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.mlp = GPTNeoXMLP(config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config)
def forward( def forward(
...@@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module): ...@@ -205,7 +211,8 @@ class GPTNeoXModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), lambda prefix: GPTNeoXLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
......
...@@ -161,7 +161,8 @@ class GraniteAttention(nn.Module): ...@@ -161,7 +161,8 @@ class GraniteAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module): ...@@ -164,7 +164,8 @@ class GraniteMoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
...@@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module): ...@@ -250,7 +250,12 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class InternLM2Model(nn.Module): class InternLM2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -266,7 +271,7 @@ class InternLM2Model(nn.Module): ...@@ -266,7 +271,7 @@ class InternLM2Model(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer( lambda prefix: layer_type(
config, cache_config, quant_config, prefix=prefix), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -316,13 +321,17 @@ class InternLM2Model(nn.Module): ...@@ -316,13 +321,17 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP): class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(vllm_config=vllm_config, self.model = model_type(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size, self.output = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention, ...@@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP, InternLM2Model) InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module): class InternLM2VEDecoderLayer(nn.Module):
...@@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module): ...@@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model): class InternLM2VEModel(InternLM2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config,
prefix=prefix,
config = vllm_config.model_config.hf_config layer_type=InternLM2VEDecoderLayer)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def forward( def forward(
self, self,
...@@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model): ...@@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM): class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config,
prefix=prefix,
self.model = InternLM2VEModel(vllm_config=vllm_config, model_type=InternLM2VEModel)
prefix=maybe_prefix(prefix, "model"))
...@@ -76,6 +76,7 @@ class JAISAttention(nn.Module): ...@@ -76,6 +76,7 @@ class JAISAttention(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -114,7 +115,8 @@ class JAISAttention(nn.Module): ...@@ -114,7 +115,8 @@ class JAISAttention(nn.Module):
scale=self.scale, scale=self.scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -178,6 +180,7 @@ class JAISBlock(nn.Module): ...@@ -178,6 +180,7 @@ class JAISBlock(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -185,7 +188,10 @@ class JAISBlock(nn.Module): ...@@ -185,7 +188,10 @@ class JAISBlock(nn.Module):
hidden_size) hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = JAISAttention(config, cache_config, quant_config) self.attn = JAISAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config) self.mlp = JAISMLP(inner_dim, config, quant_config)
...@@ -241,7 +247,8 @@ class JAISModel(nn.Module): ...@@ -241,7 +247,8 @@ class JAISModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: JAISBlock(config=config, lambda prefix: JAISBlock(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.h", prefix=f"{prefix}.h",
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment