Unverified Commit eebad39f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent db100c5c
...@@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
config: JambaConfig, config: JambaConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.mamba = MambaMixer(hidden_size= config.hidden_size, self.mamba = MambaMixer(hidden_size= config.hidden_size,
...@@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module): ...@@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module): ...@@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.attn",
) )
num_experts = config.layers_num_experts[layer_idx] num_experts = config.layers_num_experts[layer_idx]
...@@ -287,7 +290,8 @@ class JambaModel(nn.Module): ...@@ -287,7 +290,8 @@ class JambaModel(nn.Module):
layer_class(config, layer_class(config,
layer_idx=i, layer_idx=i,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config)) quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers) self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size, self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
......
...@@ -174,6 +174,7 @@ class LlamaAttention(nn.Module): ...@@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
......
...@@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module): ...@@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module): ...@@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
self.rope_scaling = getattr(config, "rope_scaling", None) self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config, self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192) "max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block() self._init_attn_block()
self._init_ffn_block() self._init_ffn_block()
...@@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config, cache_config=self.cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
) )
def _init_ffn_block(self): def _init_ffn_block(self):
...@@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module): ...@@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
): ):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config, lambda prefix: MiniCPMDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
...@@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module): ...@@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module): ...@@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): ...@@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config, cache_config=self.cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
) )
...@@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel): ...@@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
): ):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config, lambda prefix: MiniCPM3DecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
......
...@@ -166,7 +166,8 @@ class MixtralAttention(nn.Module): ...@@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -170,6 +170,7 @@ class MixtralAttention(nn.Module): ...@@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -219,7 +220,8 @@ class MixtralAttention(nn.Module): ...@@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module): ...@@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = MixtralMoE(config=config, self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config) quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
...@@ -311,7 +316,8 @@ class MixtralModel(nn.Module): ...@@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer( lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config), config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -370,6 +370,7 @@ class MolmoAttention(nn.Module): ...@@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -427,7 +428,8 @@ class MolmoAttention(nn.Module): ...@@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module): ...@@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = MolmoAttention(config, cache_config, quant_config) self.self_attn = MolmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config) self.mlp = MolmoMLP(config, quant_config=quant_config)
...@@ -738,7 +744,8 @@ class MolmoModel(nn.Module): ...@@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
else MolmoDecoderLayer else MolmoDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: decoder_layer(config, cache_config, quant_config), lambda prefix: decoder_layer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
......
...@@ -50,6 +50,7 @@ class MPTAttention(nn.Module): ...@@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -115,7 +116,8 @@ class MPTAttention(nn.Module): ...@@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -176,11 +178,15 @@ class MPTBlock(nn.Module): ...@@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.d_model hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size) self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, cache_config, quant_config) self.attn = MPTAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_2 = nn.LayerNorm(hidden_size) self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config) self.ffn = MPTMLP(config, quant_config)
...@@ -224,7 +230,8 @@ class MPTModel(nn.Module): ...@@ -224,7 +230,8 @@ class MPTModel(nn.Module):
) )
self.start_layer, self.end_layer, self.blocks = make_layers( self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers, config.n_layers,
lambda prefix: MPTBlock(config, cache_config, quant_config), lambda prefix: MPTBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks") prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:
......
...@@ -195,7 +195,8 @@ class NemotronAttention(nn.Module): ...@@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
......
...@@ -62,6 +62,7 @@ class OlmoAttention(nn.Module): ...@@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -101,7 +102,8 @@ class OlmoAttention(nn.Module): ...@@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module): ...@@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = OlmoAttention(config, cache_config, quant_config) self.self_attn = OlmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, quant_config) self.mlp = OlmoMLP(config, quant_config)
...@@ -238,8 +244,8 @@ class OlmoModel(nn.Module): ...@@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config lambda prefix: OlmoDecoderLayer(
), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False, elementwise_affine=False,
......
...@@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module): ...@@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module): ...@@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = OlmoeMoE( self.mlp = OlmoeMoE(
...@@ -260,8 +264,11 @@ class OlmoeModel(nn.Module): ...@@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int( lambda prefix: OlmoeDecoderLayer(config,
prefix.split(".")[-1]), cache_config, quant_config), int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5) self.norm = RMSNorm(config.hidden_size, eps=1e-5)
......
...@@ -75,6 +75,7 @@ class OrionAttention(nn.Module): ...@@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -126,7 +127,8 @@ class OrionAttention(nn.Module): ...@@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = OrionMLP( self.mlp = OrionMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -226,10 +230,7 @@ class OrionModel(nn.Module): ...@@ -226,10 +230,7 @@ class OrionModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: OrionDecoderLayer( lambda prefix: OrionDecoderLayer(
config, config, cache_config, quant_config, prefix=prefix),
cache_config,
quant_config,
),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module): ...@@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size() tensor_parallel_world_size = get_tensor_model_parallel_world_size()
...@@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module): ...@@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def _split_heads(self, x: torch.Tensor) -> torch.Tensor: def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
...@@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module): ...@@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config, self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PersimmonMLP(config, quant_config=quant_config) self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -226,8 +230,8 @@ class PersimmonModel(nn.Module): ...@@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config, lambda prefix: PersimmonDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
...@@ -69,7 +69,8 @@ class PhiAttention(nn.Module): ...@@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
config: PhiConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -116,7 +117,8 @@ class PhiAttention(nn.Module): ...@@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -167,11 +169,15 @@ class PhiLayer(nn.Module): ...@@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
def __init__(self, def __init__(self,
config: PhiConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, cache_config, quant_config) self.self_attn = PhiAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PhiMLP(config, quant_config) self.mlp = PhiMLP(config, quant_config)
def forward( def forward(
...@@ -210,7 +216,8 @@ class PhiModel(nn.Module): ...@@ -210,7 +216,8 @@ class PhiModel(nn.Module):
config.hidden_size) config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config), lambda prefix: PhiLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
...@@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module): ...@@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
...@@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module): ...@@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
"homo_head": self.homo_heads "homo_head": self.homo_heads
} }
self.attn = Attention( self.attn = Attention(self.num_heads_per_partition,
self.num_heads_per_partition,
self.head_dim, self.head_dim,
self.scale, self.scale,
num_kv_heads=self.num_kv_heads_per_partion, num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
blocksparse_params=bs_params, blocksparse_params=bs_params,
) prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module): ...@@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config, self.self_attn = Phi3SmallSelfAttention(config,
layer_idx, layer_idx,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Phi3SmallMLP(config, quant_config) self.mlp = Phi3SmallMLP(config, quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module): ...@@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config, lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]), int(prefix.split('.')[-1]),
cache_config, quant_config), cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
......
...@@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module): ...@@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None, rope_scaling: Optional[dict] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module): ...@@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
...@@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module): ...@@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
config: PhiMoEConfig, config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module): ...@@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
rope_scaling=config.rope_scaling, rope_scaling=config.rope_scaling,
prefix=f"{prefix}.self_attn",
) )
self.block_sparse_moe = PhiMoE( self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
...@@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module): ...@@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: PhiMoEDecoderLayer(config, cache_config, lambda prefix: PhiMoEDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
......
...@@ -442,6 +442,7 @@ class QWenAttention(nn.Module): ...@@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -478,7 +479,8 @@ class QWenAttention(nn.Module): ...@@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -502,6 +504,7 @@ class QWenBlock(nn.Module): ...@@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -514,7 +517,8 @@ class QWenBlock(nn.Module): ...@@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -568,7 +572,8 @@ class QWenModel(nn.Module): ...@@ -568,7 +572,8 @@ class QWenModel(nn.Module):
) )
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: QWenBlock(config, cache_config, quant_config), lambda prefix: QWenBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module): ...@@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module): ...@@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
...@@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module): ...@@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
layer_idx=int( layer_idx=int(
prefix.split(".")[-1]), prefix.split(".")[-1]),
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -167,6 +167,7 @@ class SolarAttention(nn.Module): ...@@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
......
...@@ -77,7 +77,8 @@ class StablelmAttention(nn.Module): ...@@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -131,7 +132,8 @@ class StablelmAttention(nn.Module): ...@@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_key_value_heads, num_kv_heads=self.num_key_value_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module): ...@@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = StablelmAttention(config, cache_config, quant_config) self.self_attn = StablelmAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config) self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
...@@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module): ...@@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: StablelmDecoderLayer(config, cache_config, lambda prefix: StablelmDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment