Commit ef179415 authored by Myle Ott
Browse files

Faster generation when using a single model (rather than ensemble)

parent bfcc6ec7
...@@ -493,25 +493,18 @@ class SequenceGenerator(object): ...@@ -493,25 +493,18 @@ class SequenceGenerator(object):
return finalized return finalized
def _decode(self, tokens, encoder_outs, incremental_states): def _decode(self, tokens, encoder_outs, incremental_states):
if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
avg_probs = None avg_probs = None
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
with torch.no_grad(): probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
if avg_probs is None: if avg_probs is None:
avg_probs = probs avg_probs = probs
else: else:
avg_probs.add_(probs) avg_probs.add_(probs)
if attn is not None: if attn is not None:
attn = attn[:, -1, :].data
if avg_attn is None: if avg_attn is None:
avg_attn = attn avg_attn = attn
else: else:
...@@ -520,5 +513,17 @@ class SequenceGenerator(object): ...@@ -520,5 +513,17 @@ class SequenceGenerator(object):
avg_probs.log_() avg_probs.log_()
if avg_attn is not None: if avg_attn is not None:
avg_attn.div_(len(self.models)) avg_attn.div_(len(self.models))
return avg_probs, avg_attn return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
    """Run a single model's decoder and return its output for the last position.

    Args:
        tokens: forwarded unchanged to ``model.decoder``.
        model: object providing ``decoder(...)`` and
            ``get_normalized_probs(...)``; it is also the key used to look
            up its entry in ``incremental_states``.
        encoder_out: forwarded unchanged to ``model.decoder``.
        incremental_states: mapping indexed by ``model``. When the entry is
            ``None`` the decoder is called without an incremental state;
            otherwise the entry is passed as the third positional argument.
        log_probs: forwarded to ``model.get_normalized_probs`` as the
            ``log_probs`` keyword.

    Returns:
        A ``(probs, attn)`` pair. ``probs`` is whatever
        ``model.get_normalized_probs`` returns for the decoder output whose
        first element has been reduced to its last index along dimension 1.
        ``attn`` is the second decoder output reduced the same way, or
        ``None`` when the decoder returned ``None`` there.
    """
    # Everything, including the normalization, runs with autograd disabled so
    # the returned tensors never carry gradient history (the previous inline
    # version detached its results explicitly with ``.data``).
    with torch.no_grad():
        state = incremental_states[model]
        if state is None:
            decoder_out = list(model.decoder(tokens, encoder_out))
        else:
            decoder_out = list(model.decoder(tokens, encoder_out, state))

        # Only the most recent position is needed to pick the next token.
        decoder_out[0] = decoder_out[0][:, -1, :]

        # Note: decoder_out[1] itself is left untouched; only the returned
        # attention is sliced.
        attn = decoder_out[1]
        if attn is not None:
            attn = attn[:, -1, :]

        probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
    return probs, attn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment