Unverified Commit fc231d3d authored by gushiqiao, committed by GitHub

Add FP8 conversion for specific weight keys (#431)

Convert weights to FP8 format using FloatQuantizer.
parent f0f13701
@@ -49,6 +49,7 @@ def main():
 print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")
 ## fp8
+weight = state_dict[key].to(torch.float32).cuda()
 w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
 weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
 weight = weight.to(torch.float8_e4m3fn)
......
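The diff calls FloatQuantizer("e4m3", True, "per_channel").real_quant_tensor without showing its internals. As a rough illustration only, the sketch below approximates what per-channel e4m3 real quantization could look like in plain PyTorch; the function name fp8_e4m3_per_channel_quant and its scale layout are hypothetical and are not part of this repository's API.

```python
# Hypothetical sketch of per-channel FP8 (e4m3) quantization; an assumption
# about the kind of transform applied here, not the actual FloatQuantizer code.
import torch

def fp8_e4m3_per_channel_quant(weight: torch.Tensor):
    """Quantize a 2-D weight to float8_e4m3fn with one scale per output channel."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # Per-output-channel absolute maximum, clamped to avoid division by zero.
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max
    # Scale into the representable range, then cast to FP8.
    q = (weight / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale.squeeze(1)

# Usage sketch, mirroring the diff (assumed shapes: weight is [out, in]):
# w = state_dict[key].to(torch.float32).cuda()
# q, s = fp8_e4m3_per_channel_quant(w)
# dequantized = q.to(torch.float32) * s.unsqueeze(1)
```

Storing the quantized tensor as float8_e4m3fn alongside a per-channel scale keeps the weight at roughly half the memory of FP16 while allowing dequantization back to higher precision at load time.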