Commit 9012e87d authored by Ning Dong's avatar Ning Dong Committed by Facebook Github Bot
Browse files

Avoid cast in PositionalEmbeddings to fix BLEU drop in pytorch native export

Summary:
Tracing mode doesn't generalize correctly in positional embedding calculation, which caused -5 BLEU at transformer export when using pytorch native.

Details: The original issue was that in ensemble_export, _to_tensor(x) in scripting mode turns integer x into 1-d tensor torch.tensor([x]), not 0-d tensor (scalar x) which is expected in the embedding. So the return value in embedding forward() is actually of wrong shape. When self.weights is of size [x,y], the return value should be (bsz, y, 1) but it was (bsz, 1, y), which caused problem in downstream computation. Tracing only becomes an issue when I used pos = timestep.view(-1)[0] to fix the shape. Then casting the scalar to primary int, to be used as index is not generalizable by tracing mode. Thus I need to convert everything to tensor and replace the advanced indexing with index_select operator.

In summary, less understood features in both scripting&tracing sides caused the bleu drop. :)

Reviewed By: myleott

Differential Revision: D16623025

fbshipit-source-id: 0c7a2c3eafbd774760a5c880c6034009ee084abb
parent 3903f469
......@@ -67,9 +67,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = (timestep.int() + 1).long() if timestep is not None else seq_len
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
if self.onnx_trace:
return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
return self.weights.index_select(index=self.padding_idx + pos, dim=0).unsqueeze(1).repeat(bsz, 1, 1)
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment