Commit 50fdf591 authored by Myle Ott's avatar Myle Ott
Browse files

Don't call forward directly (prefer module(x) to module.forward(x))

parent 5c7f4954
......@@ -18,12 +18,10 @@ class FairseqIncrementalDecoder(FairseqDecoder):
self._incremental_state = {}
def forward(self, tokens, encoder_out):
    """Abstract forward pass; concrete decoders must override this.

    Raises NotImplementedError on both the incremental-evaluation
    path and the batch path.
    """
    if not self._is_incremental_eval:
        raise NotImplementedError
    raise NotImplementedError
def incremental_forward(self, tokens, encoder_out):
    """Forward pass for a single time step.

    Only the most recent token is fed through; earlier positions are
    dropped before delegating to the regular forward pass.
    """
    last_token = tokens[:, -1:]
    return self.forward(last_token, encoder_out)
def incremental_inference(self):
"""Context manager for incremental inference.
......@@ -38,8 +36,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out, _ = model.decoder.incremental_forward(
tokens[:, :step], encoder_out)
out, _ = model.decoder(tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :])
```
"""
......
......@@ -185,6 +185,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, input_tokens, encoder_out):
    """Dispatch to the incremental or batch decoding path.

    Chooses ``incremental_forward`` when the decoder is in incremental
    evaluation mode, otherwise ``batch_forward``; arguments are passed
    through unchanged.
    """
    step_fn = (self.incremental_forward if self._is_incremental_eval
               else self.batch_forward)
    return step_fn(input_tokens, encoder_out)
def batch_forward(self, input_tokens, encoder_out):
    """Forward pass for decoding multiple time steps in batch mode.

    Builds position indices from the token tensor (respecting
    target-side left padding) and delegates to ``_forward``.
    """
    pad_idx = self.dictionary.pad()
    positions = Variable(make_positions(
        input_tokens.data, pad_idx,
        left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
    return self._forward(input_tokens, positions, encoder_out)
......
......@@ -50,13 +50,6 @@ class LinearizedConvolution(ConvTBC):
call reorder_incremental_state. To apply to fresh inputs, call
clear_incremental_state.
"""
if self.training or not self._is_incremental_eval:
raise RuntimeError('incremental_forward only supports incremental evaluation')
# run forward pre hooks (e.g., weight norm)
for hook in self._forward_pre_hooks.values():
hook(self, input)
# reshape weight
weight = self._get_linearized_weight()
kw = self.kernel_size[0]
......
......@@ -325,10 +325,7 @@ class SequenceGenerator(object):
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
if isinstance(model.decoder, FairseqIncrementalDecoder):
decoder_out, attn = model.decoder.incremental_forward(tokens, encoder_out)
else:
decoder_out, attn = model.decoder.forward(tokens, encoder_out)
decoder_out, attn = model.decoder(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :]).data
attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment