Commit 0e52a5c2, authored by TechxGenus, committed by GitHub

Fix starcoder2 fused norm (#442)

parent e9f62694
@@ -110,13 +110,9 @@ class Starcoder2Fuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj,
             )
-            norm_1 = FasterTransformerRMSNorm(
-                module.input_layernorm.weight, module.input_layernorm.eps
-            )
-            norm_2 = FasterTransformerRMSNorm(
-                module.post_attention_layernorm.weight,
-                module.post_attention_layernorm.eps,
-            )
+            # SC2 use normal LayerNorm
+            norm_1 = module.input_layernorm
+            norm_2 = module.post_attention_layernorm
             blocks.append(
                 LlamaLikeBlock(
                     hidden_size=self.model.config.hidden_size,
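
The diff comment says Starcoder2 uses a normal LayerNorm, so wrapping its norm layers in a fused RMSNorm kernel changes the computation rather than just speeding it up. The sketch below is a plain-Python illustration of why the two are not interchangeable; the function names, input vector, and bias values are hypothetical and are not taken from the repository. It assumes the usual definitions: LayerNorm subtracts the mean and adds a bias, while RMSNorm only rescales by the root mean square.

```python
import math


def layer_norm(x, weight, bias, eps=1e-5):
    """Standard LayerNorm: center, divide by the std, then scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w + b for v, w, b in zip(x, weight, bias)]


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm: divide by the root mean square and scale (no centering, no bias)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


# Hypothetical activations with a non-zero mean, plus illustrative parameters.
x = [1.0, 2.0, 3.0, 4.0]
weight = [1.0, 1.0, 1.0, 1.0]
bias = [0.5, 0.5, 0.5, 0.5]

ln_out = layer_norm(x, weight, bias)
rms_out = rms_norm(x, weight)

# Largest element-wise gap between the two normalizations.
max_gap = max(abs(a - b) for a, b in zip(ln_out, rms_out))
print(ln_out)
print(rms_out)
print(max_gap)
```

Even with identical weight and eps, the outputs diverge whenever the input mean or the bias is non-zero, which is consistent with the fix of passing the original norm modules through unchanged.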