Commit a5e2d786 authored by Haoran Li, committed by Facebook Github Bot

onnx bi-transformer (#385)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/385

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/6

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14292

Reviewed By: jingfeidu

Differential Revision: D10517864

fbshipit-source-id: 81008b5cc6aab70e23329c187392fb72ee057d78
parent 14506a83
@@ -51,5 +51,5 @@ class Highway(torch.nn.Module):
         proj_x, gate = projection.chunk(2, dim=-1)
         proj_x = self.activation(proj_x)
         gate = F.sigmoid(gate)
-        x = gate * x + (1 - gate) * proj_x
+        x = gate * x + (gate.new_tensor([1]) - gate) * proj_x
         return x
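Note: a minimal, self-contained sketch (not part of the diff) of what the Highway change above does. Lifting the Python scalar 1 to gate.new_tensor([1]) keeps the constant as a tensor with gate's dtype and device, which the ONNX tracer of this era handled more reliably than a bare Python number. The highway_gate helper and the shapes below are illustrative assumptions, not code from this commit.

import torch

# Illustrative sketch (not from the commit): the scalar 1 becomes a tensor constant
# that shares gate's dtype/device, so tracing records a tensor op instead of a
# Python-number subtraction.
def highway_gate(x, projection):
    proj_x, gate = projection.chunk(2, dim=-1)
    proj_x = torch.relu(proj_x)
    gate = torch.sigmoid(gate)
    one = gate.new_tensor([1])  # tensor constant, same dtype/device as gate
    return gate * x + (one - gate) * proj_x

# usage sketch
x = torch.randn(2, 8)
projection = torch.randn(2, 16)
y = highway_gate(x, projection)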
@@ -156,16 +156,28 @@ class MultiheadAttention(nn.Module):
         if attn_mask is not None:
             attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
         if key_padding_mask is not None:
-            key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+            key_padding_mask = torch.cat(
+                [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

         if attn_mask is not None:
-            attn_weights += attn_mask.unsqueeze(0)
+            attn_mask = attn_mask.unsqueeze(0)
+            if self.onnx_trace:
+                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+            attn_weights += attn_mask

         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.float().masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float('-inf'),
+            if self.onnx_trace:
+                attn_weights = torch.where(
+                    key_padding_mask.unsqueeze(1).unsqueeze(2),
+                    torch.Tensor([float("-Inf")]),
+                    attn_weights.float()
+                ).type_as(attn_weights)
+            else:
+                attn_weights = attn_weights.float().masked_fill(
+                    key_padding_mask.unsqueeze(1).unsqueeze(2),
+                    float('-inf'),
...
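Note: a minimal sketch (not part of the diff) comparing the two padding-mask paths above. Under onnx_trace, masked_fill with a Python scalar is swapped for torch.where with a tensor constant, since the exporter of that era could not trace the scalar fill; the explicit attn_mask.repeat likewise makes the broadcast over bsz * num_heads visible to the tracer. Both paths produce the same attention weights. The shapes and the bool mask below are illustrative assumptions.

import torch

# Illustrative sketch (not from the commit): both padding-mask paths give the same
# result; a bool mask is used here, while the original code used a ByteTensor mask.
bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_weights = torch.randn(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # last source position is padding

# eager path: masked_fill with a Python scalar
masked_fill_out = attn_weights.float().masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

# onnx_trace path: torch.where with a tensor constant
where_out = torch.where(
    key_padding_mask.unsqueeze(1).unsqueeze(2),
    torch.tensor([float('-inf')]),
    attn_weights.float())

assert torch.equal(masked_fill_out, where_out)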