"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "cd185d1f93136431f1e6c6cd6bea956301995dd6"
Commit 9012e87d authored by Ning Dong, committed by Facebook Github Bot

Avoid cast in PositionalEmbeddings to fix BLEU drop in pytorch native export

Summary:
Tracing mode doesn't generalize correctly in the positional embedding calculation, which caused a 5-point BLEU drop when exporting the transformer with PyTorch native export.
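A minimal sketch of the tracing pitfall, not the fairseq code itself. The names `weights`, `lookup_with_int_cast`, and `lookup_with_index_select` are hypothetical, and the toy table stands in for the sinusoidal weights. It assumes the problem is the usual one with `torch.jit.trace`: a value cast to a Python int is recorded as a constant, while a lookup kept in tensor form is recorded as graph operations. The tensor version uses a 1-d index via `view(-1)[:1]`, a small variation on the `view(-1)[0]` in this commit.

```python
import warnings

import torch

# Toy stand-in for the embedding table: row i holds [2*i, 2*i + 1].
weights = torch.arange(12, dtype=torch.float32).view(6, 2)


def lookup_with_int_cast(timestep):
    # int(...) leaves the tensor world, so the tracer records the value
    # seen at trace time as a constant.
    pos = int(timestep.view(-1)[0]) + 1
    return weights[pos, :].unsqueeze(0)


def lookup_with_index_select(timestep):
    # The position stays a tensor, so the lookup is recorded as graph ops.
    pos = timestep.view(-1)[:1] + 1
    return weights.index_select(0, pos)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    example = torch.tensor([1])
    traced_cast = torch.jit.trace(lookup_with_int_cast, example)
    traced_select = torch.jit.trace(lookup_with_index_select, example)

new_step = torch.tensor([3])
stale = traced_cast(new_step)    # still reads the row chosen at trace time
fresh = traced_select(new_step)  # follows the new timestep
print(stale, fresh)
```

In this sketch the int-cast version keeps returning the row selected while tracing, whereas the `index_select` version tracks the new timestep. A stale position of this kind is the sort of silent error that could degrade decoding quality.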

Details: The original issue was that in ensemble_export, _to_tensor(x) in scripting mode turns an integer x into a 1-d tensor torch.tensor([x]) rather than the 0-d (scalar) tensor that the embedding expects, so the value returned by the embedding's forward() has the wrong shape. When self.weights has size [x, y], the return value should be (bsz, y, 1), but it was (bsz, 1, y), which caused problems in downstream computation. Tracing only became an issue once I used pos = timestep.view(-1)[0] to fix the shape: casting that scalar to a primitive int for use as an index does not generalize in tracing mode. I therefore keep everything as tensors and replace the advanced indexing with the index_select operator.
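The shape difference described above can be reproduced with a small sketch. It reuses the indexing expression from the removed line of the diff, but `old_lookup`, the sizes, and the random `weights` table are illustrative assumptions rather than the actual module state.

```python
import torch

# Illustrative sizes only; the real module derives these from its config.
bsz, num_embeddings, dim = 4, 10, 6
padding_idx = 1
weights = torch.randn(num_embeddings, dim)


def old_lookup(pos):
    # Same expression as the line removed by this commit.
    return weights[padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)


scalar_pos = torch.tensor(3)    # 0-d tensor
vector_pos = torch.tensor([3])  # 1-d tensor, like torch.tensor([x])

shape_from_scalar = tuple(old_lookup(scalar_pos).shape)
shape_from_vector = tuple(old_lookup(vector_pos).shape)
print(shape_from_scalar, shape_from_vector)
```

A 0-d index removes the indexed dimension, so unsqueeze(1) and repeat produce (bsz, dim, 1). A 1-d index keeps a length-1 dimension, so the same expression produces (bsz, 1, dim).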

In summary, less understood features on both the scripting and tracing sides caused the BLEU drop. :)

Reviewed By: myleott

Differential Revision: D16623025

fbshipit-source-id: 0c7a2c3eafbd774760a5c880c6034009ee084abb
parent 3903f469
@@ -67,9 +67,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            pos = (timestep.int() + 1).long() if timestep is not None else seq_len
+            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
             if self.onnx_trace:
-                return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
+                return self.weights.index_select(index=self.padding_idx + pos, dim=0).unsqueeze(1).repeat(bsz, 1, 1)
             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
         positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace)