Unverified Commit e17deb27 authored by Yineng Zhang, committed by GitHub

fix: llama 3.1 405b fp8 (#714)

parent 2d3ae4e1
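Summary of the change (as reflected in the diff below): a `prefix` argument is threaded through `LlamaMLP`, `LlamaAttention`, `LlamaDecoderLayer`, and `LlamaModel`, so each parallel linear layer (`gate_up_proj`, `down_proj`, `qkv_proj`, `o_proj`) is constructed with its fully qualified name (e.g. `model.layers.{i}.mlp.gate_up_proj`). Quantization backends typically use such names to resolve per-layer handling, which is what the Llama 3.1 405B FP8 checkpoint needs here.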
@@ -35,6 +35,7 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,12 +43,14 @@ class LlamaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -76,6 +79,7 @@ class LlamaAttention(nn.Module):
         rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -110,12 +114,14 @@ class LlamaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
@@ -154,6 +160,7 @@ class LlamaDecoderLayer(nn.Module):
         config: LlamaConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -178,12 +185,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -231,7 +240,9 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, quant_config=quant_config)
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
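Why the prefix matters for FP8: quantization configs usually match layers by their fully qualified weight name, so each linear module has to know its own name at construction time. The sketch below is illustrative only; `Fp8LikeConfig`, `ignored_layers`, and `get_quant_method(prefix)` are assumed names for this example, not the SGLang or vLLM API.

# Minimal sketch (not the SGLang implementation): how a per-module "prefix"
# lets a quantization config decide, by fully qualified name, whether a
# given linear layer should be quantized.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Fp8LikeConfig:
    # Layers listed here (e.g. "lm_head") keep unquantized weights.
    ignored_layers: List[str] = field(default_factory=list)

    def get_quant_method(self, prefix: str) -> str:
        """Return "fp8" or "unquantized" based on the layer's full name."""
        if any(prefix == name or prefix.endswith("." + name) or prefix.startswith(name)
               for name in self.ignored_layers):
            return "unquantized"
        return "fp8"


class LinearLike:
    """Stand-in for a parallel linear layer that receives its own prefix."""

    def __init__(self, prefix: str, quant_config: Optional[Fp8LikeConfig] = None):
        self.prefix = prefix
        self.method = (
            quant_config.get_quant_method(prefix) if quant_config else "unquantized"
        )


if __name__ == "__main__":
    cfg = Fp8LikeConfig(ignored_layers=["lm_head"])
    # With the prefix threaded down LlamaModel -> LlamaDecoderLayer -> LlamaMLP,
    # each projection knows its checkpoint name and can be matched against the config.
    gate_up = LinearLike("model.layers.0.mlp.gate_up_proj", cfg)
    head = LinearLike("lm_head", cfg)
    print(gate_up.prefix, "->", gate_up.method)  # model.layers.0.mlp.gate_up_proj -> fp8
    print(head.prefix, "->", head.method)        # lm_head -> unquantized

In this sketch, omitting the prefix (as before this commit) would leave every layer with an empty name, so per-layer decisions like skipping `lm_head` could not be made reliably.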