Unverified Commit eebad39f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent db100c5c
......@@ -102,7 +102,8 @@ class JambaMambaDecoderLayer(nn.Module):
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mamba = MambaMixer(hidden_size= config.hidden_size,
......@@ -157,6 +158,7 @@ class JambaAttentionDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -198,6 +200,7 @@ class JambaAttentionDecoderLayer(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
num_experts = config.layers_num_experts[layer_idx]
......@@ -287,7 +290,8 @@ class JambaModel(nn.Module):
layer_class(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......
......@@ -174,6 +174,7 @@ class LlamaAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
......
......@@ -192,6 +192,7 @@ class MiniCPMAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -246,7 +247,8 @@ class MiniCPMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -273,6 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -283,6 +286,7 @@ class MiniCPMDecoderLayer(nn.Module):
self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block()
self._init_ffn_block()
......@@ -298,6 +302,7 @@ class MiniCPMDecoderLayer(nn.Module):
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
)
def _init_ffn_block(self):
......@@ -388,8 +393,8 @@ class MiniCPMModel(nn.Module):
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
lambda prefix: MiniCPMDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
......@@ -60,6 +60,7 @@ class MiniCPM3Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -119,7 +120,8 @@ class MiniCPM3Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -195,6 +197,7 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{self.prefix}.self_attn",
)
......@@ -209,8 +212,8 @@ class MiniCPM3Model(MiniCPMModel):
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
lambda prefix: MiniCPM3DecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
......
......@@ -166,7 +166,8 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -170,6 +170,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -219,7 +220,8 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -243,6 +245,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -255,7 +258,9 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = MixtralMoE(config=config,
quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
......@@ -311,7 +316,8 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config),
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
......
......@@ -370,6 +370,7 @@ class MolmoAttention(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -427,7 +428,8 @@ class MolmoAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection.
self.o_proj = RowParallelLinear(
......@@ -517,10 +519,14 @@ class MolmoDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
# Attention block.
self.self_attn = MolmoAttention(config, cache_config, quant_config)
self.self_attn = MolmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config)
......@@ -738,7 +744,8 @@ class MolmoModel(nn.Module):
else MolmoDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer(config, cache_config, quant_config),
lambda prefix: decoder_layer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
......
......@@ -50,6 +50,7 @@ class MPTAttention(nn.Module):
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.d_model = config.d_model
......@@ -115,7 +116,8 @@ class MPTAttention(nn.Module):
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -176,11 +178,15 @@ class MPTBlock(nn.Module):
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config, cache_config, quant_config)
self.attn = MPTAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
......@@ -224,7 +230,8 @@ class MPTModel(nn.Module):
)
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: MPTBlock(config, cache_config, quant_config),
lambda prefix: MPTBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
......
......@@ -195,7 +195,8 @@ class NemotronAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......
......@@ -62,6 +62,7 @@ class OlmoAttention(nn.Module):
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -101,7 +102,8 @@ class OlmoAttention(nn.Module):
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
# Attention output projection.
self.o_proj = RowParallelLinear(
......@@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Attention block.
self.self_attn = OlmoAttention(config, cache_config, quant_config)
self.self_attn = OlmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
......@@ -238,8 +244,8 @@ class OlmoModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
),
lambda prefix: OlmoDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
......
......@@ -102,6 +102,7 @@ class OlmoeAttention(nn.Module):
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -156,7 +157,8 @@ class OlmoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -182,6 +184,7 @@ class OlmoeDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -199,6 +202,7 @@ class OlmoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = OlmoeMoE(
......@@ -260,8 +264,11 @@ class OlmoeModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
lambda prefix: OlmoeDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
......
......@@ -75,6 +75,7 @@ class OrionAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -126,7 +127,8 @@ class OrionAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -150,6 +152,7 @@ class OrionDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -166,6 +169,7 @@ class OrionDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
......@@ -226,10 +230,7 @@ class OrionModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OrionDecoderLayer(
config,
cache_config,
quant_config,
),
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
......
......@@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
......@@ -122,7 +123,8 @@ class PersimmonAttention(nn.Module):
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
......@@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
......@@ -226,8 +230,8 @@ class PersimmonModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PersimmonDecoderLayer(config, cache_config,
quant_config),
lambda prefix: PersimmonDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
......
......@@ -69,7 +69,8 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
......@@ -116,7 +117,8 @@ class PhiAttention(nn.Module):
self.head_size,
scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -167,11 +169,15 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, cache_config, quant_config)
self.self_attn = PhiAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = PhiMLP(config, quant_config)
def forward(
......@@ -210,7 +216,8 @@ class PhiModel(nn.Module):
config.hidden_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiLayer(config, cache_config, quant_config),
lambda prefix: PhiLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
......
......@@ -117,6 +117,7 @@ class Phi3SmallSelfAttention(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
......@@ -214,15 +215,14 @@ class Phi3SmallSelfAttention(nn.Module):
"homo_head": self.homo_heads
}
self.attn = Attention(
self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
)
self.attn = Attention(self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -259,13 +259,15 @@ class Phi3SmallDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config,
layer_idx,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Phi3SmallMLP(config, quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -315,7 +317,9 @@ class Phi3SmallModel(nn.Module):
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config, quant_config),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
......
......@@ -294,6 +294,7 @@ class PhiMoEAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -347,6 +348,7 @@ class PhiMoEAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
......@@ -371,6 +373,7 @@ class PhiMoEDecoderLayer(nn.Module):
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -385,6 +388,7 @@ class PhiMoEDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=config.rope_scaling,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts,
......@@ -454,8 +458,8 @@ class PhiMoEModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PhiMoEDecoderLayer(config, cache_config,
quant_config),
lambda prefix: PhiMoEDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,
......
......@@ -442,6 +442,7 @@ class QWenAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
......@@ -478,7 +479,8 @@ class QWenAttention(nn.Module):
self.head_dim,
self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -502,6 +504,7 @@ class QWenBlock(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -514,7 +517,8 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -568,7 +572,8 @@ class QWenModel(nn.Module):
)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: QWenBlock(config, cache_config, quant_config),
lambda prefix: QWenBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
......
......@@ -168,6 +168,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -220,7 +221,8 @@ class Qwen2MoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -245,6 +247,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -261,6 +264,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
......@@ -336,7 +340,8 @@ class Qwen2MoeModel(nn.Module):
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config),
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -167,6 +167,7 @@ class SolarAttention(nn.Module):
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
......
......@@ -77,7 +77,8 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
......@@ -131,7 +132,8 @@ class StablelmAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -155,9 +157,13 @@ class StablelmDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config, cache_config, quant_config)
self.self_attn = StablelmAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
......@@ -207,8 +213,8 @@ class StableLMEpochModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: StablelmDecoderLayer(config, cache_config,
quant_config),
lambda prefix: StablelmDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
norm_eps = getattr(config, "norm_eps",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment