Commit 50fdf591 authored by Myle Ott's avatar Myle Ott
Browse files

Don't call forward directly (prefer module(x) to module.forward(x))

parent 5c7f4954
...@@ -18,12 +18,10 @@ class FairseqIncrementalDecoder(FairseqDecoder): ...@@ -18,12 +18,10 @@ class FairseqIncrementalDecoder(FairseqDecoder):
self._incremental_state = {} self._incremental_state = {}
def forward(self, tokens, encoder_out): def forward(self, tokens, encoder_out):
if self._is_incremental_eval:
raise NotImplementedError
else:
raise NotImplementedError raise NotImplementedError
def incremental_forward(self, tokens, encoder_out):
"""Forward pass for one time step."""
# keep only the last token for incremental forward pass
return self.forward(tokens[:, -1:], encoder_out)
def incremental_inference(self): def incremental_inference(self):
"""Context manager for incremental inference. """Context manager for incremental inference.
...@@ -38,8 +36,7 @@ class FairseqIncrementalDecoder(FairseqDecoder): ...@@ -38,8 +36,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
``` ```
with model.decoder.incremental_inference(): with model.decoder.incremental_inference():
for step in range(maxlen): for step in range(maxlen):
out, _ = model.decoder.incremental_forward( out, _ = model.decoder(tokens[:, :step], encoder_out)
tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :]) probs = torch.nn.functional.log_softmax(out[:, -1, :])
``` ```
""" """
......
...@@ -185,6 +185,13 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -185,6 +185,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, input_tokens, encoder_out): def forward(self, input_tokens, encoder_out):
if self._is_incremental_eval:
return self.incremental_forward(input_tokens, encoder_out)
else:
return self.batch_forward(input_tokens, encoder_out)
def batch_forward(self, input_tokens, encoder_out):
"""Forward pass for decoding multiple time steps in batch mode."""
positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(), positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
left_pad=LanguagePairDataset.LEFT_PAD_TARGET)) left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
return self._forward(input_tokens, positions, encoder_out) return self._forward(input_tokens, positions, encoder_out)
......
...@@ -50,13 +50,6 @@ class LinearizedConvolution(ConvTBC): ...@@ -50,13 +50,6 @@ class LinearizedConvolution(ConvTBC):
call reorder_incremental_state. To apply to fresh inputs, call call reorder_incremental_state. To apply to fresh inputs, call
clear_incremental_state. clear_incremental_state.
""" """
if self.training or not self._is_incremental_eval:
raise RuntimeError('incremental_forward only supports incremental evaluation')
# run forward pre hooks (e.g., weight norm)
for hook in self._forward_pre_hooks.values():
hook(self, input)
# reshape weight # reshape weight
weight = self._get_linearized_weight() weight = self._get_linearized_weight()
kw = self.kernel_size[0] kw = self.kernel_size[0]
......
...@@ -325,10 +325,7 @@ class SequenceGenerator(object): ...@@ -325,10 +325,7 @@ class SequenceGenerator(object):
avg_probs = None avg_probs = None
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
if isinstance(model.decoder, FairseqIncrementalDecoder): decoder_out, attn = model.decoder(tokens, encoder_out)
decoder_out, attn = model.decoder.incremental_forward(tokens, encoder_out)
else:
decoder_out, attn = model.decoder.forward(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :]).data probs = F.softmax(decoder_out[:, -1, :]).data
attn = attn[:, -1, :].data attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None: if avg_probs is None or avg_attn is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment