"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "e42b0c4f704fa0f5e262f370dccac537b5edf2b1"
Commit bb5f15d1 authored by Myle Ott

Iterate on need_attn and fix tests

parent 498a186d
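This commit replaces the per-call need_attn argument on decoder forward() methods with a need_attn attribute on each decoder (and on TransformerDecoderLayer), toggled once through make_generation_fast_(need_attn=...). Below is a minimal, self-contained sketch of the resulting pattern; ToyDecoder and its shapes are illustrative stand-ins, not fairseq classes.

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Toy decoder: attention output is controlled by an attribute instead of a
    per-call forward() argument (mirroring the need_attn change in this commit)."""

    def __init__(self, vocab_size=8, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)
        self.need_attn = True  # default: expose attention, e.g. during validation

    def forward(self, prev_output_tokens, encoder_out):
        x = self.embed(prev_output_tokens)  # B x T x C
        attn = None
        if self.need_attn:
            # B x T x S alignment-style scores over the source states
            attn = torch.bmm(x, encoder_out.transpose(1, 2)).softmax(dim=-1)
        return self.proj(x), attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        # called once before generation; --print-alignment maps to need_attn=True
        self.need_attn = need_attn


decoder = ToyDecoder()
decoder.make_generation_fast_(need_attn=False)  # skip attention unless alignments are wanted
logits, attn = decoder(torch.zeros(2, 3, dtype=torch.long), torch.randn(2, 5, 4))
assert attn is None

Generation scripts enable attention only when alignments are requested (need_attn=args.print_alignment in the hunks below), so beam search can skip the attention bookkeeping by default.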
@@ -14,7 +14,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         raise NotImplementedError
 
     def reorder_incremental_state(self, incremental_state, new_order):
...
@@ -104,9 +104,9 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens, need_attn=False):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out, need_attn=need_attn)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out)
         return decoder_out
 
     def max_positions(self):
...
@@ -352,6 +352,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.dropout = dropout
         self.normalization_constant = normalization_constant
         self.left_pad = left_pad
+        self.need_attn = True
 
         convolutions = extend_conv_spec(convolutions)
         in_channels = convolutions[0][0]
@@ -417,7 +418,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
         if encoder_out_dict is not None:
             encoder_out = encoder_out_dict['encoder_out']
             encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -467,7 +468,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                if need_attn:
+                if self.need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
@@ -523,6 +524,9 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _embed_tokens(self, tokens, incremental_state):
         if incremental_state is not None:
             # keep only the last token for incremental forward pass
...
@@ -259,6 +259,7 @@ class FConvDecoder(FairseqDecoder):
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
         self.left_pad = left_pad
+        self.need_attn = True
 
         in_channels = convolutions[0][0]
 
         def expand_bool_array(val):
@@ -352,7 +353,7 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
 
-    def forward(self, prev_output_tokens, encoder_out_dict, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
         trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
@@ -388,7 +389,7 @@ class FConvDecoder(FairseqDecoder):
             r = x
             x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
             x = x + r
-            if need_attn:
+            if self.need_attn:
                 if avg_attn_scores is None:
                     avg_attn_scores = attn_scores
                 else:
@@ -427,6 +428,9 @@ class FConvDecoder(FairseqDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _split_encoder_out(self, encoder_out):
         """Split and transpose encoder outputs."""
         # transpose only once to speed up attention layers
...
@@ -298,6 +298,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
         self.share_input_output_embed = share_input_output_embed
+        self.need_attn = True
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
@@ -324,7 +325,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -395,7 +396,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2) if need_attn else None
+        attn_scores = attn_scores.transpose(0, 2) if self.need_attn else None
 
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
@@ -426,6 +427,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
...
@@ -193,6 +193,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
+        self.need_attn = True
 
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -215,7 +216,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
@@ -266,6 +267,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
         return state_dict
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -339,8 +343,9 @@ class TransformerDecoderLayer(nn.Module):
         self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
+        self.need_attn = True
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, need_attn=False):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(
@@ -364,7 +369,7 @@ class TransformerDecoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             incremental_state=incremental_state,
             static_kv=True,
-            need_weights=need_attn,
+            need_weights=self.need_attn,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -387,6 +392,9 @@ class TransformerDecoderLayer(nn.Module):
         else:
             return x
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
...
@@ -54,7 +54,7 @@ class SequenceGenerator(object):
     def generate_batched_itr(
         self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-        cuda=False, timer=None, prefix_size=0, with_attention=False,
+        cuda=False, timer=None, prefix_size=0,
     ):
         """Iterate over a batched dataset and yield individual translations.
 
         Args:
@@ -81,7 +81,6 @@ class SequenceGenerator(object):
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
-                    with_attention=with_attention,
                 )
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))
@@ -91,12 +90,12 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens, with_attention)
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -127,9 +126,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
-        attn_buf = attn.clone()
-        nonpad_idxs = src_tokens.ne(self.pad) if with_attention else None
+        attn, attn_buf = None, None
+        nonpad_idxs = None
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -193,7 +191,7 @@ class SequenceGenerator(object):
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
-            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
+            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
 
            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
@@ -222,7 +220,7 @@ class SequenceGenerator(object):
            def get_hypo():
 
-                if with_attention:
+                if attn_clone is not None:
                    # remove padding tokens from attn scores
                    hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                    _, alignment = hypo_attn.max(dim=0)
@@ -275,8 +273,7 @@ class SequenceGenerator(object):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
-            probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states, with_attention)
+            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
 
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
@@ -292,6 +289,10 @@ class SequenceGenerator(object):
            # Record attention scores
            if avg_attn_scores is not None:
+                if attn is None:
+                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+                    attn_buf = attn.clone()
+                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)
 
            cand_scores = buffer('cand_scores', type_of=scores)
@@ -423,8 +424,9 @@ class SequenceGenerator(object):
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
-                attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
-                attn_buf.resize_as_(attn)
+                if attn is not None:
+                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
+                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None
@@ -479,15 +481,17 @@ class SequenceGenerator(object):
            )
 
            # copy attention for active hypotheses
-            torch.index_select(
-                attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
-                out=attn_buf[:, :, :step + 2],
-            )
+            if attn is not None:
+                torch.index_select(
+                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
+                    out=attn_buf[:, :, :step + 2],
+                )
 
            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
-            attn, attn_buf = attn_buf, attn
+            if attn is not None:
+                attn, attn_buf = attn_buf, attn
 
            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx
@@ -498,16 +502,14 @@ class SequenceGenerator(object):
        return finalized
 
-    def _decode(self, tokens, encoder_outs, incremental_states, with_attention):
+    def _decode(self, tokens, encoder_outs, incremental_states):
        if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True,
-                                    with_attention=with_attention)
+            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
 
        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False,
-                                           with_attention=with_attention)
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
            if avg_probs is None:
                avg_probs = probs
            else:
@@ -523,13 +525,12 @@ class SequenceGenerator(object):
            avg_attn.div_(len(self.models))
        return avg_probs, avg_attn
 
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs, with_attention):
+    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
        with torch.no_grad():
            if incremental_states[model] is not None:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model],
-                                                 need_attn=with_attention))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
            else:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=None, need_attn=with_attention))
+                decoder_out = list(model.decoder(tokens, encoder_out))
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
            if attn is not None:
...
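The SequenceGenerator hunks above switch from eagerly allocating the attention buffer to allocating it lazily: attn, attn_buf and nonpad_idxs start as None and are created only the first time a decoder step actually returns attention weights, and every later use is guarded by an `attn is not None` check. Below is a stripped-down sketch of that allocation pattern; it is illustrative only (the real generator also maintains a swap buffer and reorders beams).

import torch


def lazy_attn_demo(steps, bsz_beam=6, srclen=5, maxlen=4, produce_attn=False):
    """Allocate the attention buffer only if some step actually yields attention."""
    attn = None
    for step in range(steps):
        # stand-in for the per-step decoder call
        avg_attn_scores = torch.rand(bsz_beam, srclen) if produce_attn else None
        if avg_attn_scores is not None:
            if attn is None:  # first attention seen: allocate the full buffer
                attn = avg_attn_scores.new_zeros(bsz_beam, srclen, maxlen + 2)
            attn[:, :, step + 1].copy_(avg_attn_scores)
    return attn


assert lazy_attn_demo(3, produce_attn=False) is None      # no decoder attention -> no buffer
assert lazy_attn_demo(3, produce_attn=True) is not None   # buffer created on demand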
@@ -42,7 +42,10 @@ def main(args):
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
         if args.fp16:
             model.half()
@@ -88,7 +91,6 @@ def main(args):
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
             cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
-            with_attention=args.print_alignment,
         )
 
         wps_meter = TimeMeter()
...
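In the generate.py hunk above, the one-line make_generation_fast_() call becomes a multi-line call that also passes need_attn=args.print_alignment. For that flag to reach the decoders and decoder layers that gained make_generation_fast_ in this commit, the model-level hook has to fan the keyword out to its submodules. The sketch below shows one way such a dispatch can work; it is a toy illustration, not fairseq's actual BaseFairseqModel implementation, and it assumes every submodule hook accepts **kwargs (as the new hooks in this diff do).

import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for a decoder layer that exposes the per-module hook."""

    def __init__(self):
        super().__init__()
        self.need_attn = True

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


class ToyModel(nn.Module):
    """Fans make_generation_fast_ kwargs out to every submodule defining the hook."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer() for _ in range(2)])

    def make_generation_fast_(self, **kwargs):
        for module in self.modules():
            if module is not self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)


model = ToyModel()
model.make_generation_fast_(need_attn=True)  # e.g. need_attn=args.print_alignment
assert all(layer.need_attn for layer in model.layers)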
@@ -81,7 +81,10 @@ def main(args):
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
         if args.fp16:
             model.half()
@@ -112,13 +115,16 @@ def main(args):
             hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                 hypo_tokens=hypo['tokens'].int().cpu(),
                 src_str=src_str,
-                alignment=hypo['alignment'].int().cpu(),
+                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                 align_dict=align_dict,
                 tgt_dict=tgt_dict,
                 remove_bpe=args.remove_bpe,
             )
             result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
-            result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
+            result.alignments.append(
+                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
+                if args.print_alignment else None
+            )
         return result
@@ -152,7 +158,8 @@ def main(args):
         print(result.src_str)
         for hypo, align in zip(result.hypos, result.alignments):
             print(hypo)
-            print(align)
+            if align is not None:
+                print(align)
 
 
 if __name__ == '__main__':
...
@@ -203,6 +203,7 @@ def generate_main(data_dir, extra_flags=None):
             '--max-len-b', '5',
             '--gen-subset', 'valid',
             '--no-progress-bar',
+            '--print-alignment',
         ] + (extra_flags or []),
     )
...