Commit 2a84f46b authored by Alexei Baevski, committed by Myle Ott

remove completed sentences from batch

Remove completed sentences from the batch during beam search, and allow batching sequences with uneven source lengths (with fixes to make padded sequences work correctly in all models).
parent bcdc27dc
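For context, the padding fixes in this commit follow one pattern: build a boolean mask of padded source positions from the pad index, zero out padded encoder activations, and set attention scores over padding to -inf before the softmax. A minimal standalone sketch (PAD, the shapes, and the variable names here are illustrative assumptions, not the fairseq API):

import torch
import torch.nn.functional as F

PAD = 1  # assumed padding index for this sketch; not taken from a fairseq dictionary

# a right-padded batch with uneven source lengths (B x T)
src_tokens = torch.tensor([[4, 5, 6, 7],
                           [4, 5, PAD, PAD]])
encoder_padding_mask = src_tokens.eq(PAD)  # B x T, True at padded positions

# illustrative attention scores over source positions (B x tgt_len x src_len)
attn_scores = torch.randn(2, 3, 4)

# don't attend over padding: -inf before the softmax becomes exactly zero weight after it
attn_scores = attn_scores.float().masked_fill(
    encoder_padding_mask.unsqueeze(1), float('-inf')
).type_as(attn_scores)
attn_weights = F.softmax(attn_scores, dim=-1)
print(attn_weights[1, :, 2:])  # all zeros: the shorter sentence gets no attention on its padding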
@@ -183,7 +183,8 @@ class LanguageDatasets(object):
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending)
descending=descending,
allow_different_src_lens=True)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
@@ -369,7 +370,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False, required_batch_size_multiple=1):
descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
@@ -382,7 +383,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False,
ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
required_batch_size_multiple=required_batch_size_multiple,
))
@@ -26,9 +26,15 @@ class FairseqIncrementalDecoder(FairseqDecoder):
"""
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(incremental_state, new_order)
module.reorder_incremental_state(
incremental_state,
new_order,
)
self.apply(apply_reorder_incremental_state)
def reorder_encoder_out(self, encoder_out, new_order):
return encoder_out
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
if getattr(self, '_beam_size', -1) != beam_size:
@@ -103,15 +103,12 @@ class FConvEncoder(FairseqEncoder):
self.num_attention_layers = None
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
if embed_dict:
self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
self.embed_positions = PositionalEmbedding(
max_positions,
embed_dim,
padding_idx,
self.padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
)
@@ -142,12 +139,21 @@ class FConvEncoder(FairseqEncoder):
# project to size of convolution
x = self.fc1(x)
# used to mask padding in input
encoder_padding_mask = src_tokens.eq(self.padding_idx).t() # -> T x B
if not encoder_padding_mask.any():
encoder_padding_mask = None
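# a None mask means this batch has no padding, so all masking below is skipped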
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# temporal convolutions
for proj, conv in zip(self.projections, self.convolutions):
residual = x if proj is None else proj(x)
if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
x = F.dropout(x, p=self.dropout, training=self.training)
if conv.kernel_size[0] % 2 == 1:
# padding is implicit in the conv
@@ -166,13 +172,20 @@ class FConvEncoder(FairseqEncoder):
# project back to size of embedding
x = self.fc2(x)
if encoder_padding_mask is not None:
encoder_padding_mask = encoder_padding_mask.t() # -> B x T
x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
# scale gradients (this only affects backward, not forward)
x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(0.5)
return x, y
return {
'encoder_out': (x, y),
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
@@ -189,13 +202,20 @@ class AttentionLayer(nn.Module):
self.bmm = bmm if bmm is not None else torch.bmm
def forward(self, x, target_embedding, encoder_out):
def forward(self, x, target_embedding, encoder_out, encoder_padding_mask):
residual = x
# attention
x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
x = self.bmm(x, encoder_out[0])
# don't attend over padding
if encoder_padding_mask is not None:
x = x.float().masked_fill(
encoder_padding_mask.unsqueeze(1),
float('-inf')
).type_as(x) # FP16 support: cast to float and back
# softmax over last dim
sz = x.size()
x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
@@ -204,9 +224,14 @@ class AttentionLayer(nn.Module):
x = self.bmm(x, encoder_out[1])
# scale attention output
# scale attention output (respecting potentially different lengths)
s = encoder_out[1].size(1)
if encoder_padding_mask is None:
x = x * (s * math.sqrt(1.0 / s))
else:
s = s - encoder_padding_mask.type_as(x).sum(dim=1, keepdim=True) # exclude padding
s = s.unsqueeze(-1)
x = x * (s * s.rsqrt())
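# note: s * sqrt(1/s) == s * s.rsqrt() == sqrt(s), so both branches scale by the square
# root of the source length; the masked branch uses each sentence's true (non-padded) length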
# project back
x = (self.out_projection(x) + residual) * math.sqrt(0.5)
@@ -274,7 +299,10 @@ class FConvDecoder(FairseqIncrementalDecoder):
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
@@ -307,7 +335,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
if attention is not None:
x = self._transpose_if_training(x, incremental_state)
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
attn_scores = attn_scores / num_attn_layers
if avg_attn_scores is None:
avg_attn_scores = attn_scores
@@ -373,6 +401,23 @@ class FConvDecoder(FairseqIncrementalDecoder):
return x
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
if encoder_out is not None:
def update_enc_out(enc_out):
return enc_out.index_select(0, new_order)
encoder_out = tuple([update_enc_out(eo) for eo in encoder_out])
utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.normal_(0, 0.1)
@@ -209,7 +209,12 @@ class LSTMEncoder(FairseqEncoder):
dim=0).view(bsz, self.output_units))
return x, bi_final_hiddens, bi_final_cells
return x, final_hiddens, final_cells
encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
return {
'encoder_out': (x, final_hiddens, final_cells),
'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
@@ -223,7 +228,7 @@ class AttentionLayer(nn.Module):
self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)
def forward(self, input, source_hids, src_lengths=None):
def forward(self, input, source_hids, encoder_padding_mask):
# input: bsz x input_embed_dim
# source_hids: srclen x bsz x output_embed_dim
@@ -232,6 +237,14 @@ class AttentionLayer(nn.Module):
# compute attention
attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
# don't attend over padding
if encoder_padding_mask is not None:
attn_scores = attn_scores.float().masked_fill_(
encoder_padding_mask,
float('-inf')
).type_as(attn_scores) # FP16 support: cast to float and back
attn_scores = F.softmax(attn_scores.t(), dim=1).t() # srclen x bsz
# sum weighted sources
@@ -278,7 +291,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.additional_fc = Linear(hidden_size, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()
@@ -324,7 +340,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
# apply attention using the last layer's hidden state
if self.attention is not None:
out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask)
else:
out = hidden
out = F.dropout(out, p=self.dropout_out, training=self.training)
@@ -371,6 +387,13 @@ class LSTMDecoder(FairseqIncrementalDecoder):
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder_out'] = tuple(
eo.index_select(1, new_order) for eo in encoder_out_dict['encoder_out'])
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
return encoder_out_dict
def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
@@ -233,6 +233,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
return state_dict
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = encoder_out['encoder_padding_mask'].index_select(0, new_order)
return encoder_out
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
@@ -112,7 +112,7 @@ class SequenceGenerator(object):
# compute the encoder output for each beam
encoder_out = model.encoder(
src_tokens.repeat(1, beam_size).view(-1, srclen),
src_lengths.repeat(beam_size),
src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
)
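# note: the expand/transpose form above orders lengths as [l_1 x beam, l_2 x beam, ...],
# matching src_tokens.repeat(1, beam_size).view(-1, srclen); a plain src_lengths.repeat(beam_size)
# would interleave them as [l_1..l_B, l_1..l_B, ...] instead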
encoder_outs.append(encoder_out)
@@ -200,10 +200,20 @@ class SequenceGenerator(object):
if self.normalize_scores:
eos_scores /= (step+1)**self.len_penalty
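# map indices in the current (shrunken) batch back to the original batch:
# cum_unfin[i] counts how many sentences before the i-th still-unfinished sentence
# have already been finalized, so sent = unfin_idx + cum_unfin[unfin_idx] below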
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
sent = idx // beam_size
sents_seen.add(sent)
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
def get_hypo():
_, alignment = attn_clone[i].max(dim=0)
@@ -230,23 +240,27 @@ class SequenceGenerator(object):
'idx': idx,
}
# return number of hypotheses finished this step
num_finished = 0
for sent in sents_seen:
newly_finished = []
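# newly_finished collects batch-local (unfinished) indices of sentences completed
# at this step; the caller uses them to drop those sentences from the batch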
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
num_finished += 1
return num_finished
newly_finished.append(unfin_idx)
return newly_finished
reorder_state = None
batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
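# a sentence now at row i previously sat at row batch_idxs[i], so its flat hypothesis
# indices shift by (batch_idxs[i] - i) * beam_size before indexing the old, larger state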
for i, model in enumerate(self.models):
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(
incremental_states[model], reorder_state)
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
probs, avg_attn_scores = self._decode(
tokens[:, :step+1], encoder_outs, incremental_states)
@@ -308,6 +322,7 @@ class SequenceGenerator(object):
else:
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
torch.topk(
probs.view(bsz, -1),
k=min(cand_size, probs.view(bsz, -1).size(1) - 1), # -1 so we never select pad
@@ -323,18 +338,20 @@ class SequenceGenerator(object):
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= finalize_hypos(
step, eos_bbsz_idx, eos_scores)
num_remaining_sent -= len(finalize_hypos(
step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
finalized_sents = set()
if step >= self.minlen:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
@@ -348,14 +365,42 @@ class SequenceGenerator(object):
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
num_remaining_sent -= finalize_hypos(
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
if len(finalized_sents) > 0:
# construct batch_idxs which holds indices of batches to keep for the next pass
new_bsz = bsz - len(finalized_sents)
batch_mask = torch.ones(bsz).type_as(cand_indices)
batch_mask[torch.LongTensor(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
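The batch shrinking above boils down to the following standalone sketch (tensor names and shapes are illustrative, not the fairseq state layout): each per-hypothesis tensor is viewed as bsz rows, the rows of finished sentences are dropped, and the result is flattened back to new_bsz * beam_size hypothesis rows.

import torch

bsz, beam_size = 3, 2
scores = torch.randn(bsz * beam_size, 5)       # any per-hypothesis state tensor
finalized_sents = [1]                          # sentences that finished this step

batch_mask = torch.ones(bsz)
batch_mask[torch.LongTensor(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1)  # indices of sentences to keep
new_bsz = bsz - len(finalized_sents)

# keep only the rows of unfinished sentences, then flatten back to hypothesis rows
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)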
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
@@ -382,6 +427,7 @@ class SequenceGenerator(object):
cand_scores, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
@@ -425,7 +471,7 @@ class SequenceGenerator(object):
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(bsz):
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
@@ -63,7 +63,7 @@ def main(args):
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
args.gen_subset,
max_sentences=args.max_sentences,
max_sentences=args.max_sentences or 128,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
)