"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "f080a83511511a9c0a222451a752a1623aec095d"
Unverified Commit f0eecee6 authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files

[Bugfix] Fix dummy weight for fp8 (#4916)



Allow dummy load format for fp8,
torch.uniform_ doesn't support FP8 at the moment
Co-authored-by: default avatarMor Zusman <morz@ai21.com>
parent 943e72ca
......@@ -369,4 +369,11 @@ def initialize_dummy_weights(
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
param.data.uniform_(low, high)
if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment