Unverified Commit 4a0a1cd0 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Fix w4a16 conversion failed which is brought by PR 803 (#847)

parent 3295eac3
......@@ -152,6 +152,7 @@ class BaseOutputModel(ABC):
tensor.contiguous().cpu().numpy().tofile(path)
if self.to_file:
if torch.is_floating_point(param):
torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
torch.float16)
param = param.to(torch_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment