Commit 14506a83 authored by Myle Ott, committed by Facebook Github Bot

Fix some recursive functions (e.g., reorder_incremental_state) to only touch each module once (#379)

Summary:
These recursive helpers could previously visit the same module more than once, which can happen if a module is registered in more than one place in the network (e.g., a submodule shared between two parents). Each helper now tracks visited modules in a set so every module is touched only once.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/379

Differential Revision: D13154498

Pulled By: myleott

fbshipit-source-id: a35575d1956a46cd35ac8b16a719ad20ac3e380a
parent 3c19878f
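For context, the duplicate visits come from torch.nn.Module.apply: it deduplicates children registered under several names within a single parent, but a module registered under two different parents (for example, an embedding shared between an encoder and a decoder) is visited once per parent. Below is a minimal sketch of the problem and the fix; the Inner/Branch/Net modules are hypothetical and not part of this commit:

import torch.nn as nn

class Inner(nn.Module):
    pass

class Branch(nn.Module):
    def __init__(self, shared):
        super().__init__()
        # the same Inner instance is registered in both branches
        self.shared = shared

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        shared = Inner()
        self.a = Branch(shared)
        self.b = Branch(shared)

net = Net()

# Without deduplication, apply() reaches the shared module once per parent.
visits = []
net.apply(lambda m: visits.append(m))
assert sum(isinstance(m, Inner) for m in visits) == 2

# Tracking visited modules in a set, as this commit does, touches it once.
seen, visits = set(), []
def visit_once(module):
    if module not in seen:
        seen.add(module)
        visits.append(module)
net.apply(visit_once)
assert sum(isinstance(m, Inner) for m in visits) == 1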
@@ -56,19 +56,26 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         previous time step. A typical use case is beam search, where the input
         order changes between time steps based on the selection of beams.
         """
+        seen = set()
+
         def apply_reorder_incremental_state(module):
-            if module != self and hasattr(module, 'reorder_incremental_state'):
-                module.reorder_incremental_state(
-                    incremental_state,
-                    new_order,
-                )
+            if module != self and hasattr(module, 'reorder_incremental_state') \
+                    and module not in seen:
+                seen.add(module)
+                module.reorder_incremental_state(incremental_state, new_order)
+
         self.apply(apply_reorder_incremental_state)
 
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
         if getattr(self, '_beam_size', -1) != beam_size:
+            seen = set()
+
             def apply_set_beam_size(module):
-                if module != self and hasattr(module, 'set_beam_size'):
+                if module != self and hasattr(module, 'set_beam_size') \
+                        and module not in seen:
+                    seen.add(module)
                     module.set_beam_size(beam_size)
+
             self.apply(apply_set_beam_size)
             self._beam_size = beam_size
@@ -99,8 +99,12 @@ class BaseFairseqModel(nn.Module):
         self.apply(apply_remove_weight_norm)
 
+        seen = set()
+
         def apply_make_generation_fast_(module):
-            if module != self and hasattr(module, 'make_generation_fast_'):
+            if module != self and hasattr(module, 'make_generation_fast_') \
+                    and module not in seen:
+                seen.add(module)
                 module.make_generation_fast_(**kwargs)
 
         self.apply(apply_make_generation_fast_)
 
@@ -115,8 +119,12 @@ class BaseFairseqModel(nn.Module):
 
     def prepare_for_onnx_export_(self, **kwargs):
         """Make model exportable via ONNX trace."""
+        seen = set()
+
         def apply_prepare_for_onnx_export_(module):
-            if module != self and hasattr(module, 'prepare_for_onnx_export_'):
+            if module != self and hasattr(module, 'prepare_for_onnx_export_') \
+                    and module not in seen:
+                seen.add(module)
                 module.prepare_for_onnx_export_(**kwargs)
 
         self.apply(apply_prepare_for_onnx_export_)
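The same seen-set guard appears in each of the three helpers above. As a sketch only (this helper is hypothetical and not part of fairseq), the pattern generalizes to a reusable function:

def apply_once(root, method_name, *args, **kwargs):
    """Call module.<method_name>(*args, **kwargs) on every submodule of root
    that defines it, visiting each module instance at most once even when it
    is registered in more than one place in the network."""
    seen = set()

    def visit(module):
        if module is not root and hasattr(module, method_name) \
                and module not in seen:
            seen.add(module)
            getattr(module, method_name)(*args, **kwargs)

    root.apply(visit)

# e.g., apply_once(decoder, 'reorder_incremental_state', incremental_state, new_order)

The "module is not root" check mirrors the "module != self" guard in the diff, so the root's own method is not re-invoked recursively.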