Commit 672977c1 authored by James Cross, committed by Facebook GitHub Bot

transformer onnx trace: skip no-op transpose (#333)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/333

A tiny hack to speed up inference slightly for transformer beam search after export to graph mode. Specifically, there is no need to transpose a dimension of size 1 (the sequence length of a single decoder time step during beam search) with its neighbor immediately before a view/reshape.

Reviewed By: jmp84

Differential Revision: D12833011

fbshipit-source-id: f9c344a9ad595e6e48a8a65b31cf2b1392f9b938
parent 90c01b3a
@@ -177,6 +177,11 @@ class MultiheadAttention(nn.Module):
         attn = torch.bmm(attn_weights, v)
         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        if (self.onnx_trace and attn.size(1) == 1):
+            # when ONNX tracing a single decoder step (sequence length == 1)
+            # the transpose is a no-op copy before view, thus unnecessary
+            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
+        else:
+            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn = self.out_proj(attn)
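As a sanity check of the claim in the summary, here is a minimal standalone sketch (not part of the commit; the tensor sizes are illustrative assumptions, with variable names mirroring those in MultiheadAttention) showing that when the dimension being transposed has size 1, transpose(0, 1) followed by contiguous().view(...) yields exactly the same tensor as calling view directly, so the transpose-induced copy can be skipped:

import torch

# Illustrative sizes; tgt_len == 1 models a single decoder step in beam search.
bsz, num_heads, tgt_len, head_dim = 2, 4, 1, 8
embed_dim = num_heads * head_dim

# attn has shape (bsz * num_heads, tgt_len, head_dim), as after torch.bmm above.
attn = torch.randn(bsz * num_heads, tgt_len, head_dim)

# Original path: transpose dims 0 and 1, copy to contiguous memory, then reshape.
with_transpose = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

# Shortcut path taken under ONNX tracing when attn.size(1) == 1: reshape directly.
# Transposing a size-1 dimension with its neighbor does not change the flattened
# element order, so the two results are identical.
without_transpose = attn.contiguous().view(tgt_len, bsz, embed_dim)

assert torch.equal(with_transpose, without_transpose)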