Commit f68a4435 authored by Myle Ott's avatar Myle Ott
Browse files

Bug fixes

parent 1235aa08
......@@ -12,6 +12,8 @@ from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairseq import utils
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
......@@ -88,7 +90,7 @@ class MultiheadAttention(nn.Module):
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
if key_padding_mask.max() > 0:
if utils.item(key_padding_mask.max()) > 0:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
......
......@@ -62,7 +62,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
# recompute/expand embeddings if needed
bsz, seq_len = input.size()
max_pos = self.padding_idx + 1 + seq_len
if seq_len > self.weights.size(0):
if max_pos > self.weights.size(0):
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment