Commit 2a84f46b authored by Alexei Baevski, committed by Myle Ott

remove completed sentences from batch

Remove completed sentences from the batch and allow batching sequences of uneven lengths (with fixes to make padded sequences work correctly in all models).
parent bcdc27dc
@@ -183,7 +183,8 @@ class LanguageDatasets(object):
             dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
-            descending=descending)
+            descending=descending,
+            allow_different_src_lens=True)
         batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
@@ -369,7 +370,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
 def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                     max_positions=(1024, 1024), ignore_invalid_inputs=False,
-                    descending=False, required_batch_size_multiple=1):
+                    descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
@@ -382,7 +383,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
         indices = np.flip(indices, 0)
     return list(_make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
-        ignore_invalid_inputs, allow_different_src_lens=False,
+        ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
         required_batch_size_multiple=required_batch_size_multiple,
     ))
...
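The new `allow_different_src_lens=True` flag lets a single batch mix source sentences of different lengths, so the collater has to pad them to a common width before stacking. A minimal standalone sketch of that padding step (the helper name and layout are illustrative, not the repository's collater):

    import torch

    def pad_batch(seqs, pad_idx, left_pad=False):
        """Pad a list of 1-D LongTensors to a common length with pad_idx."""
        max_len = max(s.size(0) for s in seqs)
        out = seqs[0].new_full((len(seqs), max_len), pad_idx)
        for i, s in enumerate(seqs):
            if left_pad:
                out[i, max_len - s.size(0):] = s   # pad on the left (as for source tokens)
            else:
                out[i, :s.size(0)] = s             # pad on the right (as for targets)
        return out

    # pad_batch([torch.tensor([4, 5, 6]), torch.tensor([7, 8])], pad_idx=1)
    # -> tensor([[4, 5, 6],
    #            [7, 8, 1]])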
@@ -26,9 +26,15 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         """
         def apply_reorder_incremental_state(module):
             if module != self and hasattr(module, 'reorder_incremental_state'):
-                module.reorder_incremental_state(incremental_state, new_order)
+                module.reorder_incremental_state(
+                    incremental_state,
+                    new_order,
+                )
         self.apply(apply_reorder_incremental_state)
 
+    def reorder_encoder_out(self, encoder_out, new_order):
+        return encoder_out
+
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
         if getattr(self, '_beam_size', -1) != beam_size:
...
@@ -103,15 +103,12 @@ class FConvEncoder(FairseqEncoder):
         self.num_attention_layers = None
 
         num_embeddings = len(dictionary)
-        padding_idx = dictionary.pad()
-        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        if embed_dict:
-            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
+        self.padding_idx = dictionary.pad()
+        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
         self.embed_positions = PositionalEmbedding(
             max_positions,
             embed_dim,
-            padding_idx,
+            self.padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
         )
@@ -142,12 +139,21 @@ class FConvEncoder(FairseqEncoder):
         # project to size of convolution
         x = self.fc1(x)
 
+        # used to mask padding in input
+        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
+        if not encoder_padding_mask.any():
+            encoder_padding_mask = None
+
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
 
         # temporal convolutions
         for proj, conv in zip(self.projections, self.convolutions):
             residual = x if proj is None else proj(x)
+
+            if encoder_padding_mask is not None:
+                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
+
             x = F.dropout(x, p=self.dropout, training=self.training)
             if conv.kernel_size[0] % 2 == 1:
                 # padding is implicit in the conv
@@ -166,13 +172,20 @@ class FConvEncoder(FairseqEncoder):
         # project back to size of embedding
         x = self.fc2(x)
 
+        if encoder_padding_mask is not None:
+            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
+            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
+
         # scale gradients (this only affects backward, not forward)
         x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
 
         # add output to input embedding for attention
         y = (x + input_embedding) * math.sqrt(0.5)
 
-        return x, y
+        return {
+            'encoder_out': (x, y),
+            'encoder_padding_mask': encoder_padding_mask,  # B x T
+        }
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
@@ -189,13 +202,20 @@ class AttentionLayer(nn.Module):
         self.bmm = bmm if bmm is not None else torch.bmm
 
-    def forward(self, x, target_embedding, encoder_out):
+    def forward(self, x, target_embedding, encoder_out, encoder_padding_mask):
         residual = x
 
         # attention
         x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
         x = self.bmm(x, encoder_out[0])
 
+        # don't attend over padding
+        if encoder_padding_mask is not None:
+            x = x.float().masked_fill(
+                encoder_padding_mask.unsqueeze(1),
+                float('-inf')
+            ).type_as(x)  # FP16 support: cast to float and back
+
         # softmax over last dim
         sz = x.size()
         x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
@@ -204,9 +224,14 @@ class AttentionLayer(nn.Module):
         x = self.bmm(x, encoder_out[1])
 
-        # scale attention output
+        # scale attention output (respecting potentially different lengths)
         s = encoder_out[1].size(1)
-        x = x * (s * math.sqrt(1.0 / s))
+        if encoder_padding_mask is None:
+            x = x * (s * math.sqrt(1.0 / s))
+        else:
+            s = s - encoder_padding_mask.type_as(x).sum(dim=1, keepdim=True)  # exclude padding
+            s = s.unsqueeze(-1)
+            x = x * (s * s.rsqrt())
 
         # project back
         x = (self.out_projection(x) + residual) * math.sqrt(0.5)
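Setting the scores of padded source positions to -inf before the softmax gives those positions exactly zero attention weight, and the output is then rescaled by the true (unpadded) source length instead of the padded width; `s * s.rsqrt()` is just `sqrt(s)` computed per sentence from the number of non-pad tokens. A small numeric sketch of the softmax part (illustrative values only):

    import torch
    import torch.nn.functional as F

    scores = torch.tensor([[2.0, 1.0, 3.0]])          # attention scores over 3 source positions
    pad_mask = torch.tensor([[False, False, True]])   # last position is padding

    masked = scores.float().masked_fill(pad_mask, float('-inf')).type_as(scores)
    attn = F.softmax(masked, dim=-1)
    # attn -> tensor([[0.7311, 0.2689, 0.0000]]); the padded position gets zero weight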
@@ -274,7 +299,10 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+        encoder_out = encoder_out_dict['encoder_out']
+        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
+
         # split and transpose encoder outputs
         encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
@@ -307,7 +335,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
             if attention is not None:
                 x = self._transpose_if_training(x, incremental_state)
 
-                x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
+                x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
                 attn_scores = attn_scores / num_attn_layers
                 if avg_attn_scores is None:
                     avg_attn_scores = attn_scores
@@ -373,6 +401,23 @@ class FConvDecoder(FairseqIncrementalDecoder):
 
         return x
 
+    def reorder_incremental_state(self, incremental_state, new_order):
+        super().reorder_incremental_state(incremental_state, new_order)
+        encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+        if encoder_out is not None:
+            def update_enc_out(enc_out):
+                return enc_out.index_select(0, new_order)
+
+            encoder_out = tuple([update_enc_out(eo) for eo in encoder_out])
+            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
+
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     m.weight.data.normal_(0, 0.1)
...
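`reorder_encoder_out` (and the incremental-state reordering above it) simply permutes the batch/beam dimension of the cached encoder tensors with `index_select`, so cached state stays aligned with hypotheses after beams are reshuffled or finished sentences are dropped. A toy sketch of the operation:

    import torch

    encoder_out = torch.arange(6).view(3, 2)    # 3 hypotheses x 2 features
    new_order = torch.tensor([2, 0, 0])         # keep hypothesis 2, then hypothesis 0 twice

    reordered = encoder_out.index_select(0, new_order)
    # tensor([[4, 5],
    #         [0, 1],
    #         [0, 1]])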
@@ -209,7 +209,12 @@ class LSTMEncoder(FairseqEncoder):
                 dim=0).view(bsz, self.output_units))
             return x, bi_final_hiddens, bi_final_cells
 
-        return x, final_hiddens, final_cells
+        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
+
+        return {
+            'encoder_out': (x, final_hiddens, final_cells),
+            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
+        }
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
@@ -223,7 +228,7 @@ class AttentionLayer(nn.Module):
         self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
         self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)
 
-    def forward(self, input, source_hids, src_lengths=None):
+    def forward(self, input, source_hids, encoder_padding_mask):
         # input: bsz x input_embed_dim
         # source_hids: srclen x bsz x output_embed_dim
@@ -232,6 +237,14 @@ class AttentionLayer(nn.Module):
         # compute attention
         attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
 
+        # don't attend over padding
+        if encoder_padding_mask is not None:
+            attn_scores = attn_scores.float().masked_fill_(
+                encoder_padding_mask,
+                float('-inf')
+            ).type_as(attn_scores)  # FP16 support: cast to float and back
+
         attn_scores = F.softmax(attn_scores.t(), dim=1).t()  # srclen x bsz
 
         # sum weighted sources
@@ -278,7 +291,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             self.additional_fc = Linear(hidden_size, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+        encoder_out = encoder_out_dict['encoder_out']
+        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
+
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
         bsz, seqlen = prev_output_tokens.size()
@@ -324,7 +340,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
 
             # apply attention using the last layer's hidden state
             if self.attention is not None:
-                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
+                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask)
             else:
                 out = hidden
             out = F.dropout(out, p=self.dropout_out, training=self.training)
@@ -371,6 +387,13 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         new_state = tuple(map(reorder_state, cached_state))
         utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
 
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(1, new_order) for eo in encoder_out_dict['encoder_out'])
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
...
@@ -233,6 +233,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         return state_dict
 
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
...
@@ -112,7 +112,7 @@ class SequenceGenerator(object):
             # compute the encoder output for each beam
             encoder_out = model.encoder(
                 src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.repeat(beam_size),
+                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
             )
             encoder_outs.append(encoder_out)
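The `src_lengths` change matters because `src_tokens.repeat(1, beam_size).view(-1, srclen)` lays the expanded batch out sentence-major (all beam copies of sentence 0, then sentence 1, ...), whereas `src_lengths.repeat(beam_size)` would produce the interleaved, beam-major order. The `expand(...).t().contiguous().view(-1)` form matches the token layout. A quick check of the two orderings:

    import torch

    src_lengths = torch.tensor([5, 7])   # two source sentences
    beam_size = 3

    old = src_lengths.repeat(beam_size)
    # tensor([5, 7, 5, 7, 5, 7])  -- does not match the sentence-major token layout

    new = src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1)
    # tensor([5, 5, 5, 7, 7, 7])  -- one length per beam copy, sentence-major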
@@ -200,10 +200,20 @@ class SequenceGenerator(object):
             if self.normalize_scores:
                 eos_scores /= (step+1)**self.len_penalty
 
+            cum_unfin = []
+            prev = 0
+            for f in finished:
+                if f:
+                    prev += 1
+                else:
+                    cum_unfin.append(prev)
+
             sents_seen = set()
             for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
-                sent = idx // beam_size
-                sents_seen.add(sent)
+                unfin_idx = idx // beam_size
+                sent = unfin_idx + cum_unfin[unfin_idx]
+                sents_seen.add((sent, unfin_idx))
 
                 def get_hypo():
                     _, alignment = attn_clone[i].max(dim=0)
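Once finished sentences are removed from the batch, `idx // beam_size` indexes the *shrunken* batch, so `cum_unfin` records, for each still-unfinished slot, how many earlier sentences have already finished; adding that offset recovers the original sentence index. A worked example of the mapping (made-up `finished` state):

    # suppose the original batch had 4 sentences and sentences 0 and 2 already finished
    finished = [True, False, True, False]

    cum_unfin = []
    prev = 0
    for f in finished:
        if f:
            prev += 1
        else:
            cum_unfin.append(prev)
    # cum_unfin == [1, 2]: unfinished slot 0 is original sentence 0 + 1 = 1,
    #                      unfinished slot 1 is original sentence 1 + 2 = 3

    unfin_idx = 1                              # index within the shrunken batch
    sent = unfin_idx + cum_unfin[unfin_idx]    # == 3, the original sentence index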
@@ -230,23 +240,27 @@ class SequenceGenerator(object):
                         'idx': idx,
                     }
 
-            # return number of hypotheses finished this step
-            num_finished = 0
-            for sent in sents_seen:
+            newly_finished = []
+            for sent, unfin_idx in sents_seen:
                 # check termination conditions for this sentence
                 if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                     finished[sent] = True
-                    num_finished += 1
-            return num_finished
+                    newly_finished.append(unfin_idx)
+            return newly_finished
 
         reorder_state = None
+        batch_idxs = None
         for step in range(maxlen + 1):  # one extra step for EOS marker
             # reorder decoder internal states based on the prev choice of beams
             if reorder_state is not None:
-                for model in self.models:
+                if batch_idxs is not None:
+                    # update beam indices to take into account removed sentences
+                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
+                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
+                for i, model in enumerate(self.models):
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
-                        model.decoder.reorder_incremental_state(
-                            incremental_states[model], reorder_state)
+                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
+                    encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
             probs, avg_attn_scores = self._decode(
                 tokens[:, :step+1], encoder_outs, incremental_states)
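When sentences were dropped at the previous step, `reorder_state` still holds indices into the new, compact `bsz * beam_size` layout, while the incremental decoder states and `encoder_outs` are still sized for the old batch. `batch_idxs` lists which old batch rows survived, so `batch_idxs - arange(n)` is, per surviving sentence, how many slots it moved up, and multiplying by `beam_size` shifts all of that sentence's beam indices back into the old layout. A small worked example (made-up indices):

    import torch

    beam_size = 2
    # the old batch had 3 sentences; sentence 1 finished, so sentences 0 and 2 remain
    batch_idxs = torch.tensor([0, 2])

    # reorder_state currently indexes the *new* compact layout (2 sentences x 2 beams)
    reorder_state = torch.tensor([0, 1, 2, 3])

    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)   # tensor([0, 1])
    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
    # reorder_state -> tensor([0, 1, 4, 5]): rows of the old 3x2 layout, skipping sentence 1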
@@ -308,6 +322,7 @@ class SequenceGenerator(object):
             else:
                 # take the best 2 x beam_size predictions. We'll choose the first
                 # beam_size of these which don't predict eos to continue with.
                 torch.topk(
                     probs.view(bsz, -1),
                     k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
@@ -323,18 +338,20 @@ class SequenceGenerator(object):
                     descending=True,
                     out=(eos_scores, eos_bbsz_idx),
                 )
-                num_remaining_sent -= finalize_hypos(
-                    step, eos_bbsz_idx, eos_scores)
+                num_remaining_sent -= len(finalize_hypos(
+                    step, eos_bbsz_idx, eos_scores))
                 assert num_remaining_sent == 0
                 break
 
             # cand_bbsz_idx contains beam indices for the top candidate
             # hypotheses, with a range of values: [0, bsz*beam_size),
             # and dimensions: [bsz, cand_size]
-            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
+            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
 
             # finalize hypotheses that end in eos
             eos_mask = cand_indices.eq(self.eos)
+
+            finalized_sents = set()
             if step >= self.minlen:
                 # only consider eos when it's among the top beam_size indices
                 torch.masked_select(
@@ -348,14 +365,42 @@ class SequenceGenerator(object):
                     mask=eos_mask[:, :beam_size],
                     out=eos_scores,
                 )
-                num_remaining_sent -= finalize_hypos(
-                    step, eos_bbsz_idx, eos_scores, cand_scores)
+                finalized_sents = finalize_hypos(
+                    step, eos_bbsz_idx, eos_scores, cand_scores)
+                num_remaining_sent -= len(finalized_sents)
 
             assert num_remaining_sent >= 0
             if num_remaining_sent == 0:
                 break
             assert step < maxlen
 
+            if len(finalized_sents) > 0:
+                # construct batch_idxs which holds indices of batches to keep for the next pass
+                new_bsz = bsz - len(finalized_sents)
+
+                batch_mask = torch.ones(bsz).type_as(cand_indices)
+                batch_mask[torch.LongTensor(finalized_sents)] = 0
+                batch_idxs = batch_mask.nonzero().squeeze(-1)
+
+                eos_mask = eos_mask[batch_idxs]
+                cand_beams = cand_beams[batch_idxs]
+                bbsz_offsets.resize_(new_bsz, 1)
+                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+                cand_scores = cand_scores[batch_idxs]
+                cand_indices = cand_indices[batch_idxs]
+
+                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+                scores_buf.resize_as_(scores)
+                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+                tokens_buf.resize_as_(tokens)
+                attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
+                attn_buf.resize_as_(attn)
+                bsz = new_bsz
+            else:
+                batch_idxs = None
+
             # set active_mask so that values > cand_size indicate eos hypos
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
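Shrinking the batch is plain fancy indexing: `batch_idxs` selects the rows of every per-sentence tensor that correspond to still-unfinished sentences, and the per-beam tensors are first reshaped to `bsz x -1` so that whole sentences (all of their beams at once) are dropped together. A reduced sketch of that reshape-select-reshape pattern, with made-up sizes:

    import torch

    bsz, beam_size, hyp_len = 3, 2, 4
    tokens = torch.arange(bsz * beam_size * hyp_len).view(bsz * beam_size, hyp_len)

    finalized_sents = [1]                                  # sentence 1 just finished
    batch_mask = torch.ones(bsz).long()
    batch_mask[torch.LongTensor(finalized_sents)] = 0
    batch_idxs = batch_mask.nonzero().squeeze(-1)          # tensor([0, 2])

    new_bsz = bsz - len(finalized_sents)
    tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
    # tokens now has 4 rows: both beams of sentence 0 followed by both beams of sentence 2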
@@ -382,6 +427,7 @@ class SequenceGenerator(object):
                 cand_scores, dim=1, index=active_hypos,
                 out=scores[:, step].view(bsz, beam_size),
             )
 
             active_bbsz_idx = active_bbsz_idx.view(-1)
             active_scores = active_scores.view(-1)
@@ -425,7 +471,7 @@ class SequenceGenerator(object):
             reorder_state = active_bbsz_idx
 
         # sort by score descending
-        for sent in range(bsz):
+        for sent in range(len(finalized)):
             finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
 
         return finalized
...
@@ -63,7 +63,7 @@ def main(args):
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
         args.gen_subset,
-        max_sentences=args.max_sentences,
+        max_sentences=args.max_sentences or 128,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
     )
...