"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "0233efcaaa5124a4e1d3ffdbe416750fde3c22b0"
Unverified Commit 4a0a1cd0 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Fix w4a16 conversion failed which is brought by PR 803 (#847)

parent 3295eac3
...@@ -152,9 +152,10 @@ class BaseOutputModel(ABC): ...@@ -152,9 +152,10 @@ class BaseOutputModel(ABC):
tensor.contiguous().cpu().numpy().tofile(path) tensor.contiguous().cpu().numpy().tofile(path)
if self.to_file: if self.to_file:
torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type, if torch.is_floating_point(param):
torch.float16) torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
param = param.to(torch_type) torch.float16)
param = param.to(torch_type)
tprint(name, param.shape) tprint(name, param.shape)
_tofile(param, osp.join(self.out_dir, name)) _tofile(param, osp.join(self.out_dir, name))
elif len(self.tm_params) > 0: elif len(self.tm_params) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment