"...models/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a00fa4067d445eb965f0dd71e0b47082c5a77b15"
Unverified Commit fc231d3d authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Add FP8 conversion for specific weight keys (#431)

Convert weights to FP8 format using FloatQuantizer.
parent f0f13701
...@@ -49,6 +49,7 @@ def main(): ...@@ -49,6 +49,7 @@ def main():
print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}") print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")
## fp8 ## fp8
weight = state_dict[key].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel") w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight) weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn) weight = weight.to(torch.float8_e4m3fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment