Commit dccf7909 authored by Myle Ott

Add option to SequenceGenerator to retain dropout

parent c542884d
@@ -86,6 +86,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         beam_size is required if using BeamableMM.
         """
         if self._is_incremental_eval:
+            del self._incremental_state
             self._incremental_state = {}

     def apply_clear_incremental_state(module):
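The added `del` drops the old incremental-state dict before a fresh one is allocated, so stale cached decoder state cannot leak into the next generation pass. As a minimal sketch of the caching pattern this hunk touches (class and method names here are illustrative, not fairseq's actual implementation):

class TinyIncrementalDecoder:
    def __init__(self):
        # per-sequence cache of previously computed decoder state
        self._incremental_state = {}

    def clear_incremental_state(self):
        # Release the old cache before reinitializing, mirroring the
        # `del` + reassignment added in the hunk above.
        del self._incremental_state
        self._incremental_state = {}

    def step(self, key, value):
        # Reuse any cached value for this key, then update the cache.
        prev = self._incremental_state.get(key)
        self._incremental_state[key] = value
        return prev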
@@ -19,7 +19,7 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                  stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0):
+                 unk_penalty=0, retain_dropout=False):
         """Generates translations of a given source sentence.

         Args:
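The new retain_dropout keyword defaults to False, so existing callers are unaffected. A hedged usage sketch (the models list and the surrounding setup are assumptions, not part of this commit):

from fairseq.sequence_generator import SequenceGenerator

# `models` is assumed to be a list of trained fairseq models.
generator = SequenceGenerator(
    models,
    beam_size=5,
    retain_dropout=True,  # keep dropout active during generation
)
# With retain_dropout=True, repeated generation over the same input can
# yield different outputs, since each pass samples new dropout masks.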
@@ -45,6 +45,7 @@ class SequenceGenerator(object):
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
+        self.retain_dropout = retain_dropout

     def cuda(self):
         for model in self.models:
@@ -98,6 +99,7 @@ class SequenceGenerator(object):
         encoder_outs = []
         for model in self.models:
-            model.eval()
+            if not self.retain_dropout:
+                model.eval()
             if isinstance(model.decoder, FairseqIncrementalDecoder):
                 model.decoder.set_beam_size(beam_size)
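This guard is the whole mechanism: model.eval() switches every submodule to inference mode, which disables dropout, so skipping it keeps dropout masks stochastic during generation. A standalone illustration with a bare nn.Dropout:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()    # what retain_dropout=True preserves
print(drop(x))  # random zeros; survivors scaled by 1/(1-p) = 2.0

drop.eval()     # what model.eval() would have done
print(drop(x))  # identity: tensor([1., 1., 1., 1.])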
@@ -327,15 +329,19 @@ class SequenceGenerator(object):
         for model, encoder_out in zip(self.models, encoder_outs):
             decoder_out, attn = model.decoder(tokens, encoder_out)
             probs = F.softmax(decoder_out[:, -1, :], dim=1).data
-            attn = attn[:, -1, :].data
-            if avg_probs is None or avg_attn is None:
+            if avg_probs is None:
                 avg_probs = probs
-                avg_attn = attn
             else:
                 avg_probs.add_(probs)
-                avg_attn.add_(attn)
+            if attn is not None:
+                attn = attn[:, -1, :].data
+                if avg_attn is None:
+                    avg_attn = attn
+                else:
+                    avg_attn.add_(attn)
         avg_probs.div_(len(self.models))
         avg_probs.log_()
-        avg_attn.div_(len(self.models))
+        if avg_attn is not None:
+            avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
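The restructured loop averages per-model probabilities unconditionally and folds attention in only when a model returns one, so decoders that yield attn = None no longer break the ensemble path. A self-contained sketch of the same averaging logic (the function name and tensor shapes are illustrative):

import torch
import torch.nn.functional as F

def average_ensemble(decoder_outs, attns):
    # decoder_outs: list of (batch, tgt_len, vocab) tensors, one per model
    # attns: list of (batch, tgt_len, src_len) tensors, or None per model
    avg_probs, avg_attn = None, None
    for decoder_out, attn in zip(decoder_outs, attns):
        probs = F.softmax(decoder_out[:, -1, :], dim=1)
        avg_probs = probs if avg_probs is None else avg_probs.add_(probs)
        if attn is not None:  # e.g. models without attention return None
            attn = attn[:, -1, :]
            avg_attn = attn if avg_attn is None else avg_attn.add_(attn)
    avg_probs.div_(len(decoder_outs)).log_()  # back to log-probabilities
    if avg_attn is not None:
        avg_attn.div_(len(decoder_outs))
    return avg_probs, avg_attn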