Commit 5a2f76ed authored by Ning Dong, committed by Facebook GitHub Bot

NAT productionization

Summary:
NAT productionization diff.

(1) Integrate NAT model training / evaluation into the LATTE base training workflow.
(2) Make NAT tracing-compliant: since it calls into the Fairseq transformer, the code had to be refactored, so I created a ~copy of it named fb_tracing_transformer (a minimal illustration of the tracing constraint follows below).
(3) The decoder-side C++ code landed in an earlier diff.
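
For context on (2): torch.jit.script compiles only a statically typed subset of Python, which is why the stock Fairseq transformer calls had to be refactored into a scripting-friendly copy. A minimal sketch of the constraint, illustrative only and not part of this diff (the function name is hypothetical):

import torch
from torch import Tensor

@torch.jit.script
def add_lengths(lengths: Tensor, ins_counts: Tensor) -> Tensor:
    # TorchScript needs explicit tensor types; dynamic attributes and
    # heterogeneous dicts used freely in eager-mode code fail to compile.
    return lengths + ins_counts.sum(1)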

Reviewed By: xianxl

Differential Revision: D17888324

fbshipit-source-id: ef4ef195fddd360da921502adcef82b087e46ce6
parent 8defa9d9
@@ -7,6 +7,7 @@ import math

 import torch.nn.functional as F
 from fairseq import utils
+import torch
 from torch import Tensor

 from . import FairseqCriterion, register_criterion
@@ -44,23 +45,27 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
             if dim is None
             else x.float().mean(dim).type_as(x)
         )

         if masks is not None:
             outputs, targets = outputs[masks], targets[masks]

-        logits = F.log_softmax(outputs, dim=-1)
-        if targets.dim() == 1:
-            losses = F.nll_loss(logits, targets, reduction="none")
-        else:  # soft-labels
-            losses = F.kl_div(logits, targets, reduction="none")
-            losses = losses.float().sum(-1).type_as(losses)
-
-        nll_loss = mean_ds(losses)
-        if label_smoothing > 0:
-            loss = nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
-        else:
+        if not masks.any():
+            nll_loss = torch.tensor(0)
             loss = nll_loss
+        else:
+            logits = F.log_softmax(outputs, dim=-1)
+            if targets.dim() == 1:
+                losses = F.nll_loss(logits, targets.to(logits.device), reduction='none')
+            else:  # soft-labels
+                losses = F.kl_div(logits, targets.to(logits.device), reduction='none')
+                losses = losses.sum(-1)
+
+            nll_loss = mean_ds(losses)
+            if label_smoothing > 0:
+                loss = nll_loss * (
+                    1 - label_smoothing) - mean_ds(logits) * label_smoothing
+            else:
+                loss = nll_loss

         loss = loss * factor
         return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
......
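
Editorial note on the hunk above: the new `not masks.any()` branch guards batches where every position is masked out, in which case indexing yields empty tensors and the mean reduction returns NaN. A toy reproduction of the failure mode the guard avoids (shapes are assumed for illustration):

import torch

outputs = torch.randn(4, 7, 100)              # batch x length x vocab (assumed)
masks = torch.zeros(4, 7, dtype=torch.bool)   # nothing left to score

empty = outputs[masks]        # shape (0, 100)
print(empty.float().mean())   # tensor(nan), hence the early-out to a zero loss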
@@ -142,6 +142,8 @@ class LanguagePairDataset(FairseqDataset):
             target if it's absent (default: False).
         align_dataset (torch.utils.data.Dataset, optional): dataset
             containing alignments.
+        append_bos (bool, optional): if set, appends bos to the beginning of
+            source/target sentence.
     """

     def __init__(
@@ -152,6 +154,7 @@ class LanguagePairDataset(FairseqDataset):
         shuffle=True, input_feeding=True,
         remove_eos_from_source=False, append_eos_to_target=False,
         align_dataset=None,
+        append_bos=False
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()
@@ -174,6 +177,7 @@ class LanguagePairDataset(FairseqDataset):
         self.align_dataset = align_dataset
         if self.align_dataset is not None:
             assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
+        self.append_bos = append_bos

     def __getitem__(self, index):
         tgt_item = self.tgt[index] if self.tgt is not None else None
@@ -187,6 +191,15 @@ class LanguagePairDataset(FairseqDataset):
             if self.tgt and self.tgt[index][-1] != eos:
                 tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

+        if self.append_bos:
+            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
+            if self.tgt and self.tgt[index][0] != bos:
+                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
+
+            bos = self.src_dict.bos()
+            if self.src[index][-1] != bos:
+                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
+
         if self.remove_eos_from_source:
             eos = self.src_dict.eos()
             if self.src[index][-1] == eos:
......
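
Editorial note: `append_bos` mirrors the existing `append_eos_to_target` option, prepending bos when a sentence does not already start with it. The core pattern in isolation (the bos index 0 is illustrative):

import torch

bos = 0  # illustrative bos index
tgt_item = torch.LongTensor([5, 6, 7])

# prepend bos only when it is not already there
if tgt_item[0] != bos:
    tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

print(tgt_item)  # tensor([0, 5, 6, 7])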
@@ -4,21 +4,25 @@
 # LICENSE file in the root directory of this source tree.

 import torch

+from fairseq import utils
-from fairseq.models.model_utils import skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+from fairseq.models.model_utils import (
+    script_skip_tensor_list,
+    skip_tensors as _skip,
+)


 class IterativeRefinementGenerator(object):
-    def __init__(self,
-                 tgt_dict,
-                 eos_penalty=0.,
-                 max_iter=10,
-                 max_ratio=2,
-                 decoding_format=None,
-                 retain_dropout=False,
-                 adaptive=True):
+    def __init__(
+        self,
+        models,
+        tgt_dict,
+        eos_penalty=0.0,
+        max_iter=10,
+        max_ratio=2,
+        decoding_format=None,
+        retain_dropout=False,
+        adaptive=True,
+    ):
         """
         Generates translations based on iterative refinement.
@@ -42,34 +46,67 @@ class IterativeRefinementGenerator(object):
         self.decoding_format = decoding_format
         self.retain_dropout = retain_dropout
         self.adaptive = adaptive
+        self.models = models
+
+    def generate_batched_itr(
+        self,
+        data_itr,
+        maxlen_a=None,
+        maxlen_b=None,
+        cuda=False,
+        timer=None,
+        prefix_size=0,
+    ):
+        """Iterate over a batched dataset and yield individual translations.
+
+        Args:
+            maxlen_a/b: generate sequences of maximum length ax + b,
+                where x is the source sentence length.
+            cuda: use GPU for generation
+            timer: StopwatchMeter for timing generations.
+        """
+        for sample in data_itr:
+            if "net_input" not in sample:
+                continue
+            if timer is not None:
+                timer.start()
+            with torch.no_grad():
+                hypos = self.generate(
+                    sample,
+                    prefix_tokens=sample["target"][:, :prefix_size]
+                    if prefix_size > 0
+                    else None,
+                )
+            if timer is not None:
+                timer.stop(sample["ntokens"])
+            for i, id in enumerate(sample["id"]):
+                # remove padding
+                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
+                ref = utils.strip_pad(sample["target"][i, :], self.pad)
+                yield id, src, ref, hypos[i]

     @torch.no_grad()
-    def generate(self, models, sample, prefix_tokens=None):
+    def generate(self, sample, prefix_tokens=None):

-        if len(models) == 1:
-            # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
-            model = models[0]
-        elif isinstance(models[0], LevenshteinTransformerModel):
-            model = EnsembleLevT(models)
-        else:
-            raise NotImplementedError
+        # TODO: model ensemble
+        assert len(self.models) == 1, "only support single model"
+        model = self.models[0]

         if not self.retain_dropout:
             model.eval()

         # TODO: better encoder inputs?
-        src_tokens = sample['net_input']['src_tokens']
-        src_lengths = sample['net_input']['src_lengths']
+        src_tokens = sample["net_input"]["src_tokens"]
+        src_lengths = sample["net_input"]["src_lengths"]
         bsz, src_len = src_tokens.size()
-        sent_idxs = torch.arange(bsz, device=src_tokens.device)
+        sent_idxs = torch.arange(bsz)

         # encoding
         encoder_out = model.forward_encoder([src_tokens, src_lengths])

         # initialize buffers (very model specific, with length prediction or not)
-        prev_decoder_out = model.initialize_output_tokens(
-            encoder_out, src_tokens)
-        prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
+        prev_output_tokens = prev_decoder_out[0].clone()

         finalized = [[] for _ in range(bsz)]
@@ -94,23 +131,23 @@ class IterativeRefinementGenerator(object):
             hypo_attn = prev_out_attn[cutoff]
             alignment = hypo_attn.max(dim=1)[1]
             return {
-                'steps': step,
-                'tokens': tokens,
-                'positional_scores': scores,
-                'score': scores.mean(),
-                'hypo_attn': hypo_attn,
-                'alignment': alignment,
+                "steps": step,
+                "tokens": tokens,
+                "positional_scores": scores,
+                "score": scores.mean(),
+                "hypo_attn": hypo_attn,
+                "alignment": alignment,
             }

         for step in range(self.max_iter + 1):

             decoder_options = {
-                'eos_penalty': self.eos_penalty,
-                'max_ratio': self.max_ratio,
-                'decoding_format': self.decoding_format
+                "eos_penalty": self.eos_penalty,
+                "max_ratio": self.max_ratio,
+                "decoding_format": self.decoding_format,
             }
-            prev_decoder_out['step'] = step
-            prev_decoder_out['max_step'] = self.max_iter + 1
+            prev_decoder_out[3] = step
+            prev_decoder_out[4] = self.max_iter + 1

             decoder_out = model.forward_decoder(
                 prev_decoder_out, encoder_out, **decoder_options
@@ -119,24 +156,25 @@ class IterativeRefinementGenerator(object):
             if self.adaptive:
                 # terminate if there is a loop
                 terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_out_tokens, decoder_out['output_tokens'],
-                    decoder_out['output_scores'], decoder_out['attn'])
-                decoder_out['output_tokens'] = out_tokens
-                decoder_out['output_scores'] = out_scores
-                decoder_out['attn'] = out_attn
+                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                )
+                decoder_out[0] = out_tokens
+                decoder_out[1] = out_scores
+                decoder_out[2] = out_attn

             else:
-                terminated = decoder_out['output_tokens'].new_zeros(
-                    decoder_out['output_tokens'].size(0)).bool()
+                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()

             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)

             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out['output_tokens'][terminated]
-            finalized_scores = decoder_out['output_scores'][terminated]
-            finalized_attn = None if decoder_out['attn'] is None else decoder_out['attn'][terminated]
+            finalized_tokens = decoder_out[0][terminated]
+            finalized_scores = decoder_out[1][terminated]
+            finalized_attn = (
+                None if decoder_out[2] is None else decoder_out[2][terminated]
+            )

             for i in range(finalized_idxs.size(0)):
                 finalized[finalized_idxs[i]] = [
@@ -144,7 +182,7 @@ class IterativeRefinementGenerator(object):
                     step,
                     finalized_tokens[i],
                     finalized_scores[i],
-                    None if finalized_attn is None else finalized_attn[i]
+                    None if finalized_attn is None else finalized_attn[i],
                 )
             ]

             # check if all terminated
@@ -153,9 +191,9 @@ class IterativeRefinementGenerator(object):
             # for next step
             prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = _skip(encoder_out, ~terminated)
+            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
             sent_idxs = _skip(sent_idxs, ~terminated)

-            prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+            prev_output_tokens = prev_decoder_out[0].clone()

         return finalized
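
Editorial note on the hunk above: decoder state switches from a string-keyed dict (decoder_out['output_tokens'], ...) to positional indexing (decoder_out[0], ...), since TorchScript handles lists of tensors far more readily than heterogeneous dicts. The slot layout inferred from the indices used here (an assumption, not an authoritative spec):

import torch

# inferred layout: [output_tokens, output_scores, attn, step, max_step]
decoder_out = [
    torch.zeros(2, 5, dtype=torch.long),  # output_tokens
    torch.zeros(2, 5),                    # output_scores
    None,                                 # attn (optional)
    0,                                    # step
    11,                                   # max_step
]
decoder_out[3] = 1  # replaces prev_decoder_out['step'] = step from the old code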
This diff is collapsed.
@@ -3,7 +3,42 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Dict, List
+
 import torch
+from torch import Tensor
+
+
+@torch.jit.script
+def script_skip_tensor_list(x: List[Tensor], mask):
+    res = [xi[mask] if xi.size(0) == mask.size(0) else xi[:, mask] for xi in x]
+    outputs = []
+    for i, t in enumerate(res):
+        if t.numel() != 0:
+            outputs.append(t)
+        else:
+            outputs.append(x[i])
+    return outputs
+
+
+@torch.jit.script
+def script_skip_tensor(x: Tensor, mask):
+    # None case
+    if x.size(0) == 0:
+        return x
+    res = x[mask] if x.size(0) == mask.size(0) else x[:, mask]
+    if res.numel() == 0:
+        return x
+    else:
+        return res
+
+
+@torch.jit.script
+def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
+    outputs = {}
+    for s, t in x.items():
+        outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
+    return outputs


 def skip_tensors(x, mask):
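
Editorial note: the script_skip_tensor* helpers above drop finished sentences from a batch, indexing dim 0 when it matches the mask and falling back to dim 1 (e.g. for time-first encoder states). The indexing pattern they rely on, in isolation:

import torch

x = torch.arange(12).view(4, 3)                  # 4 sentences, batch-first
keep = torch.tensor([True, False, True, False])  # plays the role of ~terminated

print(x[keep])           # rows 0 and 2 survive
t_first = x.t()          # time-first: size(0) != mask.size(0)
print(t_first[:, keep])  # so columns are masked instead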
@@ -31,7 +66,8 @@ def skip_tensors(x, mask):
     raise NotImplementedError


-def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
+@torch.jit.script
+def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
     """
     Expand 2D/3D tensor on dim=1
     """
@@ -46,18 +82,18 @@ def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
     dims = [x.size(0), trg_dim - x.size(1)]
     if x.dim() == 3:
         dims.append(x.size(2))

-    x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+    x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1)

     return x


-def fill_tensors(x, mask, y, padding_idx):
+@torch.jit.script
+def fill_tensors(x, mask, y, padding_idx: int):
     """
     Filling tensor x with y at masked positions (dim=0).
     """
-    if x is None:
-        return None
+    if x is None or x.size()[0] == 0:
+        return torch.empty([0])
     assert x.dim() == y.dim() and mask.size(0) == x.size(0)
     assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
@@ -72,7 +108,7 @@ def fill_tensors(x, mask, y, padding_idx):
         x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
         x[mask] = y
     elif x.size(1) > y.size(1):
-        x[mask] = padding_idx
+        x[mask] = torch.tensor(padding_idx)
         if x.dim() == 2:
             x[mask, :y.size(1)] = y
         else:
@@ -80,3 +116,88 @@ def fill_tensors(x, mask, y, padding_idx):
     else:
         x[mask] = y
     return x
+
+
+def _apply_ins_masks(
+    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
+):
+    in_masks = in_tokens.ne(padding_idx)
+    in_lengths = in_masks.sum(1)
+
+    # HACK: hacky way to shift all the paddings to eos first.
+    in_tokens.masked_fill_(~in_masks, eos_idx)
+    mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
+
+    out_lengths = in_lengths + mask_ins_pred.sum(1)
+    out_max_len = out_lengths.max()
+    out_masks = (
+        torch.arange(out_max_len, device=out_lengths.device)[None, :]
+        < out_lengths[:, None]
+    )
+
+    reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
+    out_tokens = (
+        in_tokens.new_zeros(in_tokens.size(0), out_max_len)
+        .fill_(padding_idx)
+        .masked_fill_(out_masks, unk_idx)
+    )
+    out_tokens[:, 0] = in_tokens[:, 0]
+    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
+
+    out_scores = None
+    if in_scores is not None:
+        in_scores.masked_fill_(~in_masks, 0)
+        out_scores = in_scores.new_zeros(*out_tokens.size())
+        out_scores[:, 0] = in_scores[:, 0]
+        out_scores.scatter_(1, reordering, in_scores[:, 1:])
+
+    return out_tokens, out_scores
+
+
+def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
+    word_ins_masks = in_tokens.eq(unk_idx)
+    out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
+
+    if in_scores is not None:
+        out_scores = in_scores.masked_scatter(
+            word_ins_masks, word_ins_scores[word_ins_masks]
+        )
+    else:
+        out_scores = None
+
+    return out_tokens, out_scores
+
+
+def _apply_del_words(
+    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
+):
+    # apply deletion to a tensor
+    in_masks = in_tokens.ne(padding_idx)
+    bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
+
+    max_len = in_tokens.size(1)
+    word_del_pred.masked_fill_(~in_masks, 1)
+    word_del_pred.masked_fill_(bos_eos_masks, 0)
+
+    reordering = (
+        torch.arange(max_len, device=in_tokens.device)[None, :]
+        .expand_as(in_tokens)
+        .contiguous()
+        .masked_fill_(word_del_pred, max_len)
+        .sort(1)[1]
+    )
+
+    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
+
+    out_scores = None
+    if in_scores is not None:
+        out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
+
+    out_attn = None
+    if in_attn is not None:
+        _mask = word_del_pred[:, :, None].expand_as(in_attn)
+        _reordering = reordering[:, :, None].expand_as(in_attn)
+        out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
+
+    return out_tokens, out_scores, out_attn
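
Editorial note: the three _apply_* helpers implement the Levenshtein edit steps on padded batches: _apply_ins_masks stretches sequences with unk placeholders, _apply_ins_words fills those placeholders in, and _apply_del_words compacts deleted positions and re-pads. A toy walk-through of the placeholder-fill step (token indices are illustrative):

import torch

unk = 3
in_tokens = torch.tensor([[0, 5, unk, unk, 2]])   # bos, word, unk, unk, eos
word_ins_pred = torch.tensor([[0, 0, 7, 8, 0]])   # per-position predictions

# same masked_scatter pattern as _apply_ins_words above
mask = in_tokens.eq(unk)
out_tokens = in_tokens.masked_scatter(mask, word_ins_pred[mask])
print(out_tokens)  # tensor([[0, 5, 7, 8, 2]])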
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 import math

 from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
-from fairseq.models.levenshtein_transformer import _apply_del_words, _apply_ins_masks, _apply_ins_words
+from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words


 class BasicEnsembleModel(torch.nn.Module):
......