Commit 3a64aced authored by James Cross, committed by Facebook GitHub Bot

work around lack of optional output for forks (#429)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/429

Pull Request resolved: https://github.com/pytorch/fairseq/pull/618

PyTorch export for transformer models was broken because, as written, they used a placeholder `None` value for the variable `key_padding_mask` during inference to indicate no padding, but PyTorch is unable to trace such values. This diff adds a minor hack that allows an empty tensor to be used for the same purpose.

Reviewed By: jmp84

Differential Revision: D14581730

fbshipit-source-id: 2ea4664c20ecab8478c578b2182a85319140036c
parent 10ad7495
@@ -145,6 +145,11 @@ class MultiheadAttention(nn.Module):
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
    key_padding_mask = None
if key_padding_mask is not None:
    assert key_padding_mask.size(0) == bsz
    assert key_padding_mask.size(1) == src_len
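
For reference, a minimal sketch of the convention this hack enables (the helper name `resolve_key_padding_mask` is illustrative, not part of fairseq): when tracing, the caller passes a 0-dimensional tensor in place of `None`, and the module maps it back to `None` before the mask is applied.

```python
import torch

def resolve_key_padding_mask(key_padding_mask):
    # Mirrors the check added in the diff: a 0-dim placeholder tensor
    # (shape == torch.Size([])) stands in for "no padding" during tracing
    # and is converted back to None before the mask is used.
    if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
        return None
    return key_padding_mask

# At export time the caller passes torch.empty(()) instead of None so that
# every traced input is a real Tensor.
assert resolve_key_padding_mask(torch.empty(())) is None

# A real padding mask (batch x src_len) passes through unchanged.
mask = torch.zeros(2, 7, dtype=torch.bool)
assert resolve_key_padding_mask(mask) is mask
```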