Commit dccf7909 authored by Myle Ott
Browse files

Add option to SequenceGenerator to retain dropout

parent c542884d
......@@ -86,6 +86,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
beam_size is required if using BeamableMM.
"""
if self._is_incremental_eval:
del self._incremental_state
self._incremental_state = {}
def apply_clear_incremental_state(module):
......
......@@ -19,7 +19,7 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0):
unk_penalty=0, retain_dropout=False):
"""Generates translations of a given source sentence.
Args:
......@@ -45,6 +45,7 @@ class SequenceGenerator(object):
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
def cuda(self):
for model in self.models:
......@@ -98,7 +99,8 @@ class SequenceGenerator(object):
encoder_outs = []
for model in self.models:
model.eval()
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.set_beam_size(beam_size)
......@@ -327,15 +329,19 @@ class SequenceGenerator(object):
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :], dim=1).data
attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None:
if avg_probs is None:
avg_probs = probs
avg_attn = attn
else:
avg_probs.add_(probs)
avg_attn.add_(attn)
if attn is not None:
attn = attn[:, -1, :].data
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs.div_(len(self.models))
avg_probs.log_()
avg_attn.div_(len(self.models))
if avg_attn is not None:
avg_attn.div_(len(self.models))
return avg_probs, avg_attn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment