Commit 50cf3bb5 authored by Ning Dong, committed by Facebook Github Bot

Fix LevT generator interface

Summary: Revert the interface change for iterative_refinement_generator

Reviewed By: kahne

Differential Revision: D18165103

fbshipit-source-id: 075c276746eb90d7c359b6ad92e1ef25e8452bcc
parent dabbef46
@@ -5,17 +5,16 @@
 import torch
 from fairseq import utils
-from fairseq.models.model_utils import (
-    script_skip_tensor_list,
-    skip_tensors as _skip,
-)
+from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
+from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
 class IterativeRefinementGenerator(object):
     def __init__(
         self,
-        models,
         tgt_dict,
+        models=None,
         eos_penalty=0.0,
         max_iter=10,
         max_ratio=2,
@@ -73,6 +72,7 @@ class IterativeRefinementGenerator(object):
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
+                    self.models,
                     sample,
                     prefix_tokens=sample["target"][:, :prefix_size]
                     if prefix_size > 0
@@ -87,11 +87,15 @@ class IterativeRefinementGenerator(object):
                 yield id, src, ref, hypos[i]
 
     @torch.no_grad()
-    def generate(self, sample, prefix_tokens=None):
-        # TODO: model ensemble
-        assert len(self.models) == 1, "only support single model"
-        model = self.models[0]
+    def generate(self, models, sample, prefix_tokens=None):
+        if len(models) == 1:
+            # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
+            model = models[0]
+        elif isinstance(models[0], LevenshteinTransformerModel):
+            model = EnsembleLevT(models)
+        else:
+            raise NotImplementedError
 
         if not self.retain_dropout:
             model.eval()
...
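For context, a minimal usage sketch of the interface after this revert: the generator is constructed without models, and the model list is passed to generate() on each call (a single model is used directly, while multiple Levenshtein Transformer models are wrapped in EnsembleLevT). The checkpoint path, the batch iterator, and the loading helpers below are assumptions for illustration and are not part of this commit; the import assumes the generator lives at fairseq/iterative_refinement_generator.py.

    from fairseq import checkpoint_utils, tasks
    from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

    # Hypothetical checkpoint path; load_model_ensemble returns (models, args).
    models, args = checkpoint_utils.load_model_ensemble(["checkpoints/checkpoint_best.pt"])
    task = tasks.setup_task(args)

    generator = IterativeRefinementGenerator(
        tgt_dict=task.target_dictionary,  # models are no longer bound at construction time
        eos_penalty=0.0,
        max_iter=10,
    )

    # The model list is passed per call: one model is used as-is, while several
    # LevenshteinTransformerModel instances are wrapped in EnsembleLevT.
    for sample in batch_iterator:  # batch_iterator: an iterator of prepared batches (assumed)
        hypos = generator.generate(models, sample)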