Commit eb68afca authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot

fix a type mismatch in NAT quantization run

Summary:
Fix a type mismatch found after patching NAT on top of quantization.
Ning suggested this fix. Still to understand: why does this only appear after patching the quantization diff?

Reviewed By: kahne, jhcross

Differential Revision: D18147726

fbshipit-source-id: a51becc9ad58a637a0180074eaa2b46990ab9f84
parent c07362c6
@@ -108,7 +108,7 @@ def fill_tensors(x, mask, y, padding_idx: int):
         x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
         x[mask] = y
     elif x.size(1) > y.size(1):
-        x[mask] = torch.tensor(padding_idx)
+        x[mask] = torch.tensor(padding_idx).type_as(x)
         if x.dim() == 2:
             x[mask, :y.size(1)] = y
         else:
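A minimal sketch of the mismatch this one-liner fixes (shapes and values here are hypothetical, not from the NAT code): `torch.tensor(padding_idx)` defaults to `int64`, and masked assignment in PyTorch requires the source and destination dtypes to match, so writing it into a lower-precision tensor (e.g. `float16` under quantization) can fail; `.type_as(x)` casts the scalar to `x`'s dtype first.

```python
import torch

# Hypothetical destination tensor; under quantization it may be float16
# rather than the default float32, which is where the mismatch surfaces.
x = torch.zeros(2, 4, dtype=torch.float16)
mask = torch.tensor([True, False])
padding_idx = 1

# Before the fix: torch.tensor(padding_idx) is int64, and
# `x[mask] = torch.tensor(padding_idx)` can raise a dtype-mismatch error.
# After the fix: cast the scalar to x's dtype before the masked assignment.
x[mask] = torch.tensor(padding_idx).type_as(x)
```

The cast keeps the assignment dtype-safe regardless of whether `x` is `float32`, `float16`, or an integer tensor, which is why the fix is a single `.type_as(x)` rather than a change to how `x` is created.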