Commit f68a4435 authored by Myle Ott

Bug fixes

parent 1235aa08
@@ -12,6 +12,8 @@ from torch import nn
 from torch.nn import Parameter
 import torch.nn.functional as F
 
+from fairseq import utils
+
 
 class MultiheadAttention(nn.Module):
     """Multi-headed attention.
@@ -88,7 +90,7 @@ class MultiheadAttention(nn.Module):
             attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
         if key_padding_mask is not None:
            # don't attend to padding symbols
-            if key_padding_mask.max() > 0:
+            if utils.item(key_padding_mask.max()) > 0:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
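The change above wraps key_padding_mask.max() in utils.item() so the padding check compares a plain Python number rather than a tensor. A minimal standalone sketch of the idea, assuming a helper that behaves the way fairseq's utils.item is presumed to (pull the scalar out of a one-element tensor, pass anything else through):

import torch

def item(x):
    # Sketch of an item()-style helper (assumption, not fairseq's actual code):
    # extract a Python scalar from a one-element tensor so that comparisons
    # such as `item(mask.max()) > 0` yield a plain bool instead of a tensor.
    if torch.is_tensor(x) and x.numel() == 1:
        return x.item()
    return x

key_padding_mask = torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.uint8)
if item(key_padding_mask.max()) > 0:
    print("batch contains padding; apply the padding mask")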
@@ -62,7 +62,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
        # recompute/expand embeddings if needed
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
-        if seq_len > self.weights.size(0):
+        if max_pos > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
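The second change fixes the resize check for the precomputed sinusoidal table: positions are offset by padding_idx + 1, so the table needs max_pos = padding_idx + 1 + seq_len rows, not just seq_len. A minimal sketch of the corrected check, with a hypothetical expand_weights_if_needed helper standing in for the real recompute via SinusoidalPositionalEmbedding.get_embedding:

import torch

def expand_weights_if_needed(weights, padding_idx, seq_len, embedding_dim):
    # Hypothetical helper illustrating the corrected bounds check: positions
    # start at padding_idx + 1, so the table must hold max_pos rows.
    max_pos = padding_idx + 1 + seq_len
    if max_pos > weights.size(0):
        # Stand-in for SinusoidalPositionalEmbedding.get_embedding(max_pos, ...);
        # here we only allocate a table of the right size.
        weights = torch.zeros(max_pos, embedding_dim)
    return weights

weights = torch.zeros(8, 4)  # table sized for a shorter sequence
weights = expand_weights_if_needed(weights, padding_idx=1, seq_len=10, embedding_dim=4)
print(weights.size(0))       # 12 == padding_idx + 1 + seq_len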