Commit 14506a83 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix some recursive functions (e.g., reorder_incremental_state) to only touch...

Fix some recursive functions (e.g., reorder_incremental_state) to only touch each module once (#379)

Summary:
This can happen if a module is registered in more than one place in the network.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/379

Differential Revision: D13154498

Pulled By: myleott

fbshipit-source-id: a35575d1956a46cd35ac8b16a719ad20ac3e380a
parent 3c19878f
......@@ -56,19 +56,26 @@ class FairseqIncrementalDecoder(FairseqDecoder):
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
seen = set()
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(
incremental_state,
new_order,
)
if module != self and hasattr(module, 'reorder_incremental_state') \
and module not in seen:
seen.add(module)
module.reorder_incremental_state(incremental_state, new_order)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
if getattr(self, '_beam_size', -1) != beam_size:
seen = set()
def apply_set_beam_size(module):
if module != self and hasattr(module, 'set_beam_size'):
if module != self and hasattr(module, 'set_beam_size') \
and module not in seen:
seen.add(module)
module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
self._beam_size = beam_size
......@@ -99,8 +99,12 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_remove_weight_norm)
seen = set()
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
if module != self and hasattr(module, 'make_generation_fast_') \
and module not in seen:
seen.add(module)
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
......@@ -115,8 +119,12 @@ class BaseFairseqModel(nn.Module):
def prepare_for_onnx_export_(self, **kwargs):
"""Make model exportable via ONNX trace."""
seen = set()
def apply_prepare_for_onnx_export_(module):
if module != self and hasattr(module, 'prepare_for_onnx_export_'):
if module != self and hasattr(module, 'prepare_for_onnx_export_') \
and module not in seen:
seen.add(module)
module.prepare_for_onnx_export_(**kwargs)
self.apply(apply_prepare_for_onnx_export_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment