"git@developer.sourcefind.cn:change/sglang.git" did not exist on "0c5532b0c1e77a9ccfb50b06c61405794f76bc15"
Unverified commit e17deb27, authored by Yineng Zhang and committed by GitHub

fix: llama 3.1 405b fp8 (#714)

parent 2d3ae4e1
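
The change threads a `prefix` argument through `LlamaMLP`, `LlamaAttention`, `LlamaDecoderLayer`, and `LlamaModel`, so each parallel linear layer is constructed knowing its fully qualified parameter name (for example `model.layers.0.self_attn.qkv_proj`). A quantization config can key per-layer decisions off that name, which matters for FP8 checkpoints such as Llama 3.1 405B where some modules carry per-layer scales or are excluded from quantization. The sketch below is a minimal illustration of that idea, not the actual sglang/vLLM API; `ToyFp8Config`, `modules_to_not_convert`, and `is_layer_skipped` are hypothetical names.

# Minimal sketch (hypothetical; not the real sglang/vLLM interface) of how a
# quantization config can use a module's fully qualified name.
from typing import List, Optional


class ToyFp8Config:
    """Decides per layer, by prefix, whether FP8 quantization applies."""

    def __init__(self, modules_to_not_convert: Optional[List[str]] = None):
        self.modules_to_not_convert = modules_to_not_convert or []

    def is_layer_skipped(self, prefix: str) -> bool:
        # Without a prefix there is nothing to match against this list,
        # which is why the constructors in this diff thread it through.
        return any(prefix.startswith(m) for m in self.modules_to_not_convert)


cfg = ToyFp8Config(modules_to_not_convert=["model.layers.0.mlp"])
print(cfg.is_layer_skipped("model.layers.0.mlp.gate_up_proj"))    # True
print(cfg.is_layer_skipped("model.layers.1.self_attn.qkv_proj"))  # False
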
@@ -35,6 +35,7 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,12 +43,14 @@ class LlamaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -76,6 +79,7 @@ class LlamaAttention(nn.Module):
         rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -110,12 +114,14 @@ class LlamaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
@@ -154,6 +160,7 @@ class LlamaDecoderLayer(nn.Module):
         config: LlamaConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -178,12 +185,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -231,7 +240,9 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, quant_config=quant_config)
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
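
With the prefixes above, the layer names built by `LlamaModel` mirror the checkpoint's parameter naming. A quick, purely illustrative way to print the scheme (it only formats strings; no model is constructed):

# Prints the fully qualified names implied by the constructor chain above.
for i in range(2):
    layer = f"model.layers.{i}"
    for name in (
        "self_attn.qkv_proj",
        "self_attn.o_proj",
        "mlp.gate_up_proj",
        "mlp.down_proj",
    ):
        print(f"{layer}.{name}")
# e.g. model.layers.0.self_attn.qkv_proj ... model.layers.1.mlp.down_proj
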