Commit e2afb03c (unverified) authored by Thomas Parnell, committed by GitHub

[Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models (#5460)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 6e2527a7
@@ -299,4 +299,10 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
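
The branch added above loads a single `c_attn` scale three times, once per shard id. A minimal sketch of that fan-out follows, assuming the fused QKV scale parameter keeps one slot per shard and its loader requires a shard id. `toy_scale_loader`, `load_scale`, `SHARD_INDEX`, and the checkpoint key are hypothetical stand-ins for illustration, not vLLM APIs.

```python
# Hypothetical stand-ins for illustration only; none of these are vLLM APIs.
SHARD_INDEX = {"q": 0, "k": 1, "v": 2}


def toy_scale_loader(param, loaded_scale, shard_id=None):
    """Store a scalar scale in the slot of one shard of a fused QKV layer."""
    if shard_id is None:
        raise ValueError("a fused QKV scale needs a shard id")
    param[SHARD_INDEX[shard_id]] = loaded_scale


def load_scale(name, param, loaded_scale, loader):
    """Mirror the branch in the diff: fan one c_attn scale out to q, k, v."""
    if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
        for shard_id in ("q", "k", "v"):
            loader(param, loaded_scale, shard_id)
    else:
        loader(param, loaded_scale)


# Assumed layout: one scale slot per shard (q, k, v).
qkv_scale = [0.0, 0.0, 0.0]
# The key below is an illustrative name, not taken from a real checkpoint.
load_scale("transformer.h.0.attn.c_attn.weight_scale", qkv_scale, 0.5,
           toy_scale_loader)
print(qkv_scale)
```

Under these assumptions, a checkpoint that stores one scale for the fused `c_attn` projection ends up with the same value in all three shard slots, which is what the three `weight_loader` calls in the commit appear to achieve.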