Unverified Commit f0eecee6 authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files

[Bugfix] Fix dummy weight for fp8 (#4916)



Allow dummy load format for fp8,
torch.uniform_ doesn't support FP8 at the moment
Co-authored-by: default avatarMor Zusman <morz@ai21.com>
parent 943e72ca
...@@ -369,4 +369,11 @@ def initialize_dummy_weights( ...@@ -369,4 +369,11 @@ def initialize_dummy_weights(
""" """
for param in model.state_dict().values(): for param in model.state_dict().values():
if torch.is_floating_point(param): if torch.is_floating_point(param):
param.data.uniform_(low, high) if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment