Commit 5a2f76ed authored by Ning Dong, committed by Facebook Github Bot

NAT productionization

Summary:

(1) Integrate NAT model training / evaluation into the LATTE base training workflow.
(2) Make NAT tracing-compliant. Since it calls into the fairseq transformer, the code needed refactoring; I created a near-copy of it named fb_tracing_transformer.
(3) The decoder-side C++ code landed in an earlier diff.

Reviewed By: xianxl

Differential Revision: D17888324

fbshipit-source-id: ef4ef195fddd360da921502adcef82b087e46ce6
parent 8defa9d9
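
A recurring pattern in the diff below: decoder state that used to travel as a dict ({"output_tokens": ..., "output_scores": ..., "attn": ...}) becomes a plain list indexed as [output_tokens, output_scores, attn, step, max_step]. That is what "tracing compliant" means in practice here: TorchScript is happiest with tensors and typed containers such as List[Tensor]. A minimal standalone sketch of the idiom (illustrative only, not code from this diff):

import torch
from torch import Tensor
from typing import List

@torch.jit.script
def skip_finished(state: List[Tensor], keep: Tensor) -> List[Tensor]:
    # Drop finished rows from every batch-major tensor in the state list.
    out: List[Tensor] = []
    for t in state:
        out.append(t[keep])
    return out

tokens = torch.tensor([[0, 5, 2], [0, 7, 2]])
scores = torch.zeros(2, 3)
keep = torch.tensor([True, False])
tokens, scores = skip_finished([tokens, scores], keep)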
@@ -7,6 +7,7 @@ import math
 import torch.nn.functional as F
 from fairseq import utils
+import torch
 from torch import Tensor

 from . import FairseqCriterion, register_criterion
@@ -44,21 +45,25 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
             if dim is None
             else x.float().mean(dim).type_as(x)
         )
         if masks is not None:
             outputs, targets = outputs[masks], targets[masks]

-        logits = F.log_softmax(outputs, dim=-1)
-        if targets.dim() == 1:
-            losses = F.nll_loss(logits, targets, reduction="none")
-        else:  # soft-labels
-            losses = F.kl_div(logits, targets, reduction="none")
-            losses = losses.float().sum(-1).type_as(losses)
-        nll_loss = mean_ds(losses)
-        if label_smoothing > 0:
-            loss = nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
-        else:
-            loss = nll_loss
+        if not masks.any():
+            nll_loss = torch.tensor(0)
+            loss = nll_loss
+        else:
+            logits = F.log_softmax(outputs, dim=-1)
+            if targets.dim() == 1:
+                losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
+            else:  # soft-labels
+                losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
+                losses = losses.sum(-1)
+            nll_loss = mean_ds(losses)
+            if label_smoothing > 0:
+                loss = nll_loss * (
+                    1 - label_smoothing) - mean_ds(logits) * label_smoothing
+            else:
+                loss = nll_loss
...
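
The guard above exists because an all-False mask would send empty tensors through log_softmax and nll_loss. A minimal standalone sketch of the guarded loss (hypothetical shapes and smoothing value; the real criterion also handles soft labels and factor weighting):

import torch
import torch.nn.functional as F

def masked_label_smoothed_nll(outputs, targets, masks, label_smoothing=0.1):
    outputs, targets = outputs[masks], targets[masks]
    if not masks.any():
        nll_loss = torch.tensor(0.0)   # nothing to score: short-circuit
        return nll_loss, nll_loss
    logits = F.log_softmax(outputs, dim=-1)
    losses = F.nll_loss(logits, targets, reduction="none")
    nll_loss = losses.mean()
    loss = nll_loss * (1 - label_smoothing) - logits.mean() * label_smoothing
    return loss, nll_loss

outputs = torch.randn(4, 7, 10)   # batch x length x vocab
targets = torch.randint(0, 10, (4, 7))
masks = targets.ne(0)             # e.g. ignore padding index 0
loss, nll = masked_label_smoothed_nll(outputs, targets, masks)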
@@ -142,6 +142,8 @@ class LanguagePairDataset(FairseqDataset):
             target if it's absent (default: False).
         align_dataset (torch.utils.data.Dataset, optional): dataset
             containing alignments.
+        append_bos (bool, optional): if set, appends bos to the beginning of
+            source/target sentence.
     """

     def __init__(
@@ -152,6 +154,7 @@ class LanguagePairDataset(FairseqDataset):
         shuffle=True, input_feeding=True,
         remove_eos_from_source=False, append_eos_to_target=False,
         align_dataset=None,
+        append_bos=False,
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()
@@ -174,6 +177,7 @@ class LanguagePairDataset(FairseqDataset):
         self.align_dataset = align_dataset
         if self.align_dataset is not None:
             assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
+        self.append_bos = append_bos

     def __getitem__(self, index):
         tgt_item = self.tgt[index] if self.tgt is not None else None
@@ -187,6 +191,15 @@ class LanguagePairDataset(FairseqDataset):
         if self.tgt and self.tgt[index][-1] != eos:
             tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

+        if self.append_bos:
+            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
+            if self.tgt and self.tgt[index][0] != bos:
+                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
+
+            bos = self.src_dict.bos()
+            # check the first token (not the last) before prepending bos
+            if self.src[index][0] != bos:
+                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
+
         if self.remove_eos_from_source:
             eos = self.src_dict.eos()
             if self.src[index][-1] == eos:
...
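
For concreteness, the effect of the new append_bos flag on a single example pair, with hypothetical dictionary indices bos=0 and eos=2:

import torch

bos, eos = 0, 2
src = torch.LongTensor([5, 6, eos])
tgt = torch.LongTensor([7, 8, eos])

# prepend bos only when it is not already there
if tgt[0] != bos:
    tgt = torch.cat([torch.LongTensor([bos]), tgt])
if src[0] != bos:
    src = torch.cat([torch.LongTensor([bos]), src])
# src -> [0, 5, 6, 2], tgt -> [0, 7, 8, 2]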
@@ -4,21 +4,25 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from fairseq.models.model_utils import skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+from fairseq import utils
+from fairseq.models.model_utils import (
+    script_skip_tensor_list,
+    skip_tensors as _skip,
+)


 class IterativeRefinementGenerator(object):
-    def __init__(self,
-                 tgt_dict,
-                 eos_penalty=0.,
-                 max_iter=10,
-                 max_ratio=2,
-                 decoding_format=None,
-                 retain_dropout=False,
-                 adaptive=True):
+    def __init__(
+        self,
+        models,
+        tgt_dict,
+        eos_penalty=0.0,
+        max_iter=10,
+        max_ratio=2,
+        decoding_format=None,
+        retain_dropout=False,
+        adaptive=True,
+    ):
         """
         Generates translations based on iterative refinement.
@@ -42,34 +46,67 @@ class IterativeRefinementGenerator(object):
         self.decoding_format = decoding_format
         self.retain_dropout = retain_dropout
         self.adaptive = adaptive
+        self.models = models
+
+    def generate_batched_itr(
+        self,
+        data_itr,
+        maxlen_a=None,
+        maxlen_b=None,
+        cuda=False,
+        timer=None,
+        prefix_size=0,
+    ):
+        """Iterate over a batched dataset and yield individual translations.
+
+        Args:
+            maxlen_a/b: generate sequences of maximum length ax + b,
+                where x is the source sentence length.
+            cuda: use GPU for generation
+            timer: StopwatchMeter for timing generations.
+        """
+        for sample in data_itr:
+            if "net_input" not in sample:
+                continue
+            if timer is not None:
+                timer.start()
+            with torch.no_grad():
+                hypos = self.generate(
+                    sample,
+                    prefix_tokens=sample["target"][:, :prefix_size]
+                    if prefix_size > 0
+                    else None,
+                )
+            if timer is not None:
+                timer.stop(sample["ntokens"])
+            for i, id in enumerate(sample["id"]):
+                # remove padding
+                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
+                ref = utils.strip_pad(sample["target"][i, :], self.pad)
+                yield id, src, ref, hypos[i]

     @torch.no_grad()
-    def generate(self, models, sample, prefix_tokens=None):
-
-        if len(models) == 1:
-            # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
-            model = models[0]
-        elif isinstance(models[0], LevenshteinTransformerModel):
-            model = EnsembleLevT(models)
-        else:
-            raise NotImplementedError
+    def generate(self, sample, prefix_tokens=None):
+
+        # TODO: model ensemble
+        assert len(self.models) == 1, "only support single model"
+        model = self.models[0]

         if not self.retain_dropout:
             model.eval()

         # TODO: better encoder inputs?
-        src_tokens = sample['net_input']['src_tokens']
-        src_lengths = sample['net_input']['src_lengths']
+        src_tokens = sample["net_input"]["src_tokens"]
+        src_lengths = sample["net_input"]["src_lengths"]
         bsz, src_len = src_tokens.size()
-        sent_idxs = torch.arange(bsz, device=src_tokens.device)
+        sent_idxs = torch.arange(bsz)

         # encoding
         encoder_out = model.forward_encoder([src_tokens, src_lengths])

         # initialize buffers (very model specific, with length prediction or not)
-        prev_decoder_out = model.initialize_output_tokens(
-            encoder_out, src_tokens)
-        prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
+        prev_output_tokens = prev_decoder_out[0].clone()

         finalized = [[] for _ in range(bsz)]
@@ -94,23 +131,23 @@ class IterativeRefinementGenerator(object):
             hypo_attn = prev_out_attn[cutoff]
             alignment = hypo_attn.max(dim=1)[1]
             return {
-                'steps': step,
-                'tokens': tokens,
-                'positional_scores': scores,
-                'score': scores.mean(),
-                'hypo_attn': hypo_attn,
-                'alignment': alignment,
+                "steps": step,
+                "tokens": tokens,
+                "positional_scores": scores,
+                "score": scores.mean(),
+                "hypo_attn": hypo_attn,
+                "alignment": alignment,
             }

         for step in range(self.max_iter + 1):
             decoder_options = {
-                'eos_penalty': self.eos_penalty,
-                'max_ratio': self.max_ratio,
-                'decoding_format': self.decoding_format
+                "eos_penalty": self.eos_penalty,
+                "max_ratio": self.max_ratio,
+                "decoding_format": self.decoding_format,
             }
-            prev_decoder_out['step'] = step
-            prev_decoder_out['max_step'] = self.max_iter + 1
+            prev_decoder_out[3] = step
+            prev_decoder_out[4] = self.max_iter + 1

             decoder_out = model.forward_decoder(
                 prev_decoder_out, encoder_out, **decoder_options
@@ -119,24 +156,25 @@ class IterativeRefinementGenerator(object):
             if self.adaptive:
                 # terminate if there is a loop
-                terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_out_tokens, decoder_out['output_tokens'],
-                    decoder_out['output_scores'], decoder_out['attn'])
-                decoder_out['output_tokens'] = out_tokens
-                decoder_out['output_scores'] = out_scores
-                decoder_out['attn'] = out_attn
+                terminated, out_tokens, out_scores, out_attn = is_a_loop(
+                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                )
+                decoder_out[0] = out_tokens
+                decoder_out[1] = out_scores
+                decoder_out[2] = out_attn
             else:
-                terminated = decoder_out['output_tokens'].new_zeros(
-                    decoder_out['output_tokens'].size(0)).bool()
+                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()

             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)

             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out['output_tokens'][terminated]
-            finalized_scores = decoder_out['output_scores'][terminated]
-            finalized_attn = None if decoder_out['attn'] is None else decoder_out['attn'][terminated]
+            finalized_tokens = decoder_out[0][terminated]
+            finalized_scores = decoder_out[1][terminated]
+            finalized_attn = (
+                None if decoder_out[2] is None else decoder_out[2][terminated]
+            )

             for i in range(finalized_idxs.size(0)):
                 finalized[finalized_idxs[i]] = [
@@ -144,7 +182,7 @@ class IterativeRefinementGenerator(object):
                         step,
                         finalized_tokens[i],
                         finalized_scores[i],
-                        None if finalized_attn is None else finalized_attn[i]
+                        None if finalized_attn is None else finalized_attn[i],
                     )
                 ]

             # check if all terminated
@@ -153,9 +191,9 @@ class IterativeRefinementGenerator(object):
             # for next step
             prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = _skip(encoder_out, ~terminated)
+            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
             sent_idxs = _skip(sent_idxs, ~terminated)

-            prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+            prev_output_tokens = prev_decoder_out[0].clone()

         return finalized
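
The loop above refines the whole batch each step and retires sentences as soon as the decoder output stops changing (is_a_loop), shrinking the batch as it goes. A toy sketch of that control flow, with refine() standing in for model.forward_decoder and scalar "hypotheses":

import torch

def refine(tokens):
    # stand-in for model.forward_decoder: output stops changing at 3
    return torch.clamp(tokens + 1, max=3)

tokens = torch.tensor([0, 1, 2, 3])
sent_idxs = torch.arange(4)
finalized = [None] * 4
max_iter = 10

for step in range(max_iter + 1):
    new_tokens = refine(tokens)
    terminated = new_tokens.eq(tokens)  # 1-step loop detection
    if step == max_iter:
        terminated.fill_(True)
    for idx, tok in zip(sent_idxs[terminated].tolist(),
                        new_tokens[terminated].tolist()):
        finalized[idx] = (step, tok)
    if bool(terminated.all()):
        break
    # keep refining only the unfinished sentences
    tokens = new_tokens[~terminated]
    sent_idxs = sent_idxs[~terminated]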
-# Copyright (c) Facebook, Inc. and its affiliates.
+#!/usr/bin/env python3
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
 #
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from typing import Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq.utils import new_arange
 from fairseq.models import register_model, register_model_architecture
-from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
-from fairseq.models.transformer import (
-    Embedding,
-    TransformerDecoder,
-    TransformerEncoder,
-    TransformerModel,
-    TransformerDecoderLayer
+from fairseq.models.fb_tracing_transformer import (
+    TracingTransformerDecoder,
+    TracingTransformerEncoder,
+    TracingTransformerModel,
+    TransformerDecoderLayer,
+)
+from fairseq.models.model_utils import (
+    fill_tensors as _fill,
+    script_skip_tensor,
+    script_skip_tensor_list,
 )
+from fairseq.models.transformer import Embedding
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from torch import Tensor

 def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
@@ -24,17 +37,16 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
         from fairseq import libnat
     except ImportError as e:
         import sys
-        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
+
+        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
         raise e

     in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)

-    with torch.cuda.device_of(in_tokens):
-        in_tokens_list = [
-            [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
-        ]
-        out_tokens_list = [
-            [t for t in s if t != padding_idx]
-            for i, s in enumerate(out_tokens.tolist())
-        ]
+    in_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+    ]
+    out_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
+    ]

     full_labels = libnat.suggested_ed2_path(
@@ -59,9 +71,7 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
     ]

     # transform to tensor
-    masked_tgt_masks = torch.tensor(
-        masked_tgt_masks, device=out_tokens.device
-    ).bool()
+    masked_tgt_masks = torch.tensor(masked_tgt_masks, device=out_tokens.device).bool()
     mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
     masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
     return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
@@ -72,17 +82,16 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
         from fairseq import libnat
     except ImportError as e:
         import sys
-        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
+
+        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
         raise e

     out_seq_len = out_tokens.size(1)

-    with torch.cuda.device_of(in_tokens):
-        in_tokens_list = [
-            [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
-        ]
-        out_tokens_list = [
-            [t for t in s if t != padding_idx]
-            for i, s in enumerate(out_tokens.tolist())
-        ]
+    in_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+    ]
+    out_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
+    ]

     full_labels = libnat.suggested_ed2_path(
@@ -95,7 +104,7 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
     ]

     # transform to tensor
-    word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
+    word_del_targets = torch.tensor(word_del_targets)
     return word_del_targets
@@ -104,17 +113,16 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
         from fairseq import libnat
     except ImportError as e:
         import sys
-        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
+
+        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
         raise e

     in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)

-    with torch.cuda.device_of(in_tokens):
-        in_tokens_list = [
-            [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
-        ]
-        out_tokens_list = [
-            [t for t in s if t != padding_idx]
-            for i, s in enumerate(out_tokens.tolist())
-        ]
+    in_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+    ]
+    out_tokens_list = [
+        [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
+    ]

     full_labels = libnat.suggested_ed2_path(
@@ -136,96 +144,13 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
     ]

     # transform to tensor
-    mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
-    word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
+    mask_ins_targets = torch.tensor(mask_ins_targets)
+    word_del_targets = torch.tensor(word_del_targets)
     return word_del_targets, mask_ins_targets

-def _apply_ins_masks(
-    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
-):
-    in_masks = in_tokens.ne(padding_idx)
-    in_lengths = in_masks.sum(1)
-
-    # HACK: hacky way to shift all the paddings to eos first.
-    in_tokens.masked_fill_(~in_masks, eos_idx)
-    mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
-
-    out_lengths = in_lengths + mask_ins_pred.sum(1)
-    out_max_len = out_lengths.max()
-    out_masks = (
-        new_arange(out_lengths, out_max_len)[None, :]
-        < out_lengths[:, None]
-    )
-
-    reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
-    out_tokens = (
-        in_tokens.new_zeros(in_tokens.size(0), out_max_len)
-        .fill_(padding_idx)
-        .masked_fill_(out_masks, unk_idx)
-    )
-    out_tokens[:, 0] = in_tokens[:, 0]
-    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
-
-    out_scores = None
-    if in_scores is not None:
-        in_scores.masked_fill_(~in_masks, 0)
-        out_scores = in_scores.new_zeros(*out_tokens.size())
-        out_scores[:, 0] = in_scores[:, 0]
-        out_scores.scatter_(1, reordering, in_scores[:, 1:])
-
-    return out_tokens, out_scores
-
-
-def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
-    word_ins_masks = in_tokens.eq(unk_idx)
-    out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
-
-    if in_scores is not None:
-        out_scores = in_scores.masked_scatter(
-            word_ins_masks, word_ins_scores[word_ins_masks]
-        )
-    else:
-        out_scores = None
-
-    return out_tokens, out_scores
-
-
-def _apply_del_words(
-    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
-):
-    # apply deletion to a tensor
-    in_masks = in_tokens.ne(padding_idx)
-    bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
-
-    max_len = in_tokens.size(1)
-    word_del_pred.masked_fill_(~in_masks, 1)
-    word_del_pred.masked_fill_(bos_eos_masks, 0)
-
-    reordering = (
-        new_arange(in_tokens)
-        .masked_fill_(word_del_pred, max_len)
-        .sort(1)[1]
-    )
-
-    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
-
-    out_scores = None
-    if in_scores is not None:
-        out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
-
-    out_attn = None
-    if in_attn is not None:
-        _mask = word_del_pred[:, :, None].expand_as(in_attn)
-        _reordering = reordering[:, :, None].expand_as(in_attn)
-        out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
-
-    return out_tokens, out_scores, out_attn

 @register_model("levenshtein_transformer")
-class LevenshteinTransformerModel(TransformerModel):
+class LevenshteinTransformerModel(TracingTransformerModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
         self.tgt_dict = decoder.dictionary
@@ -236,7 +161,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
     @staticmethod
     def add_args(parser):
-        TransformerModel.add_args(parser)
+        TracingTransformerModel.add_args(parser)
         parser.add_argument(
             "--apply-bert-init",
             action="store_true",
@@ -260,8 +185,17 @@ class LevenshteinTransformerModel(TracingTransformerModel):
         )
         parser.add_argument(
             "--sampling-for-deletion",
-            action='store_true',
-            help='instead of argmax, use sampling to predict the tokens'
+            action="store_true",
+            help="instead of argmax, use sampling to predict the tokens",
+        )
+        # Added for compatibility
+        parser.add_argument(
+            "--decoder-out-embed-dim",
+            default=None,
+            type=int,
+            metavar="N",
+            help="decoder output embedding dimension (bottleneck layer before"
+            "output layer if specified.)",
         )

     @classmethod
@@ -273,7 +207,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
     @classmethod
     def build_encoder(cls, args, src_dict, embed_tokens):
-        encoder = TransformerEncoder(args, src_dict, embed_tokens)
+        encoder = TracingTransformerEncoder(args, src_dict, embed_tokens)
         if getattr(args, "apply_bert_init", False):
             encoder.apply(init_bert_params)
         return encoder
@@ -304,8 +238,8 @@ class LevenshteinTransformerModel(TracingTransformerModel):
         # make online prediction
         if self.decoder.sampling_for_deletion:
             word_predictions = torch.multinomial(
-                F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
-                    word_ins_out.size(0), -1)
+                F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1
+            ).view(word_ins_out.size(0), -1)
         else:
             word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
@@ -315,9 +249,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):

         # generate training labels for deletion
         word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
-        word_del_out, _ = self.decoder.forward_word_del(
-            word_predictions, encoder_out)
-
+        word_del_out, _ = self.decoder.forward_word_del(word_predictions, encoder_out)
         return {
             "mask_ins_out": mask_ins_out,
             "mask_ins_tgt": mask_ins_targets,
@@ -337,123 +269,246 @@ class LevenshteinTransformerModel(TracingTransformerModel):
         self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
     ):
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
-        attn = decoder_out["attn"]
-
-        bsz = output_tokens.size(0)
-        if max_ratio is None:
-            max_lens = output_tokens.new().fill_(255)
-        else:
-            if encoder_out["encoder_padding_mask"] is None:
-                max_src_len = encoder_out["encoder_out"].size(1)
-                src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len)
-            else:
-                src_lens = (~encoder_out["encoder_padding_mask"]).sum(1)
-            max_lens = (src_lens * max_ratio).clamp(min=10).long()
+        output_tokens = decoder_out[0]
+        output_scores = decoder_out[1]
+        attn = decoder_out[2]
+
+        if max_ratio is not None and encoder_out[1] is not None:
+            max_lengths = ((~encoder_out[1]).sum(1) * max_ratio).clamp(min=10)
+        else:
+            max_lengths = torch.zeros(output_tokens.size(0)).fill_(255)
+
+        @torch.jit.script
+        def del_word(
+            output_tokens,
+            output_scores,
+            attn: Tensor,
+            word_del_attn: Optional[Tensor],
+            word_del_pred,
+            can_del_word,
+            pad_idx: int,
+            bos_idx: int,
+            eos_idx: int,
+        ):
+            # delete words
+            # do not delete tokens if it is <s> </s>
+            if can_del_word.sum() != 0:  # we cannot delete, skip
+                in_tokens = output_tokens[can_del_word]
+                in_scores = output_scores[can_del_word]
+                # apply deletion to a tensor
+                in_masks = in_tokens.ne(pad_idx)
+                bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
+
+                max_len = in_tokens.size(1)
+                word_del_pred.masked_fill_(~in_masks, 1)
+                word_del_pred.masked_fill_(bos_eos_masks, 0)
+
+                reordering = (
+                    torch.arange(max_len)[None, :]
+                    .expand_as(in_tokens)
+                    .contiguous()
+                    .masked_fill(word_del_pred, max_len)
+                    .sort(1)[1]
+                )
+
+                _tokens = in_tokens.masked_fill(word_del_pred, pad_idx).gather(
+                    1, reordering
+                )
+                _scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
+
+                if word_del_attn is not None:
+                    _mask = word_del_pred[:, :, None].expand_as(word_del_attn)
+                    _reordering = reordering[:, :, None].expand_as(word_del_attn)
+                    _attn = word_del_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
+                    attn = _fill(attn, can_del_word, _attn, 0)
+
+                output_tokens = _fill(output_tokens, can_del_word, _tokens, pad_idx)
+                output_scores = _fill(output_scores, can_del_word, _scores, 0)
+            return output_tokens, output_scores, attn
+
+        @torch.jit.script
+        def ins_placeholders(
+            output_tokens,
+            output_scores,
+            mask_ins_pred,
+            can_ins_mask,
+            pad_idx: int,
+            unk_idx: int,
+            eos_idx: int,
+        ):
+            # insert placeholders
+            if can_ins_mask.sum() != 0:
+                in_tokens = output_tokens[can_ins_mask]
+                in_scores = output_scores[can_ins_mask]
+                in_masks = in_tokens.ne(pad_idx)
+                in_lengths = in_masks.sum(1)
+
+                # HACK: hacky way to shift all the paddings to eos first.
+                in_tokens.masked_fill_(~in_masks, eos_idx)
+                mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
+
+                out_lengths = in_lengths + mask_ins_pred.sum(1)
+                out_max_len = out_lengths.max()
+                out_masks = (
+                    torch.arange(out_max_len)[None, :].long() < out_lengths[:, None]
+                )
+
+                reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
+                out_tokens = (
+                    torch.zeros(in_tokens.size()[0], out_max_len)
+                    .fill_(pad_idx)
+                    .masked_fill_(out_masks, unk_idx)
+                )
+                out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
+                out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())
+
+                if in_scores is not None:
+                    in_scores.masked_fill_(~in_masks, 0)
+                    out_scores = torch.zeros_like(out_tokens).to(in_scores)
+                    # carry over position-0 scores (mirrors the token branch above)
+                    out_scores = torch.cat([in_scores[:, :1], out_scores[:, 1:]], 1)
+                    out_scores.scatter_(1, reordering, in_scores[:, 1:])
+                else:
+                    out_scores = None
+
+                output_tokens = _fill(output_tokens, can_ins_mask, out_tokens, pad_idx)
+                output_scores = _fill(output_scores, can_ins_mask, out_scores, 0)
+            return output_tokens, output_scores
+
+        @torch.jit.script
+        def ins_words(
+            output_tokens,
+            output_scores,
+            attn: Tensor,
+            word_ins_attn,
+            word_ins_pred,
+            word_ins_scores,
+            can_ins_word,
+            pad_idx: int,
+            unk_idx: int,
+        ):
+            # insert words
+            if can_ins_word.sum() != 0:
+                in_tokens = output_tokens[can_ins_word]
+                in_scores = output_scores[can_ins_word]
+                word_ins_masks = in_tokens.eq(unk_idx)
+
+                out_tokens = in_tokens.masked_scatter(
+                    word_ins_masks, word_ins_pred[word_ins_masks].float()
+                )
+
+                if in_scores is not None:
+                    out_scores = in_scores.masked_scatter(
+                        word_ins_masks, word_ins_scores[word_ins_masks]
+                    )
+                else:
+                    out_scores = None
+
+                output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
+                output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
+                attn = _fill(attn, can_ins_word, word_ins_attn, 0)
+            return output_tokens, output_scores, attn

-        # delete words
-        # do not delete tokens if it is <s> </s>
-        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
-        if can_del_word.sum() != 0:  # we cannot delete, skip
-            word_del_out, word_del_attn = self.decoder.forward_word_del(
-                _skip(output_tokens, can_del_word), _skip(encoder_out, can_del_word)
-            )
-            word_del_score = F.log_softmax(word_del_out, 2)
-            word_del_pred = word_del_score.max(-1)[1].bool()
-
-            _tokens, _scores, _attn = _apply_del_words(
-                output_tokens[can_del_word],
-                output_scores[can_del_word],
-                word_del_attn,
-                word_del_pred,
-                self.pad,
-                self.bos,
-                self.eos,
-            )
-            output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
-            output_scores = _fill(output_scores, can_del_word, _scores, 0)
-            attn = _fill(attn, can_del_word, _attn, 0.)
+        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
+        word_del_out, word_del_attn = self.decoder.forward_word_del(
+            script_skip_tensor(output_tokens, can_del_word),
+            script_skip_tensor_list(list(encoder_out), can_del_word),
+        )
+        word_del_score = F.log_softmax(word_del_out, 2)
+        word_del_pred = word_del_score.max(-1)[1].bool()
+
+        output_tokens, output_scores, attn = del_word(
+            output_tokens,
+            output_scores,
+            attn,
+            word_del_attn,
+            word_del_pred,
+            can_del_word,
+            self.pad,
+            self.bos,
+            self.eos,
+        )

-        # insert placeholders
-        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
-        if can_ins_mask.sum() != 0:
-            mask_ins_out, _ = self.decoder.forward_mask_ins(
-                _skip(output_tokens, can_ins_mask), _skip(encoder_out, can_ins_mask)
-            )
-            mask_ins_score = F.log_softmax(mask_ins_out, 2)
-            if eos_penalty > 0.0:
-                mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
-            mask_ins_pred = mask_ins_score.max(-1)[1]
-            mask_ins_pred = torch.min(
-                mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
-            )
-
-            _tokens, _scores = _apply_ins_masks(
-                output_tokens[can_ins_mask],
-                output_scores[can_ins_mask],
-                mask_ins_pred,
-                self.pad,
-                self.unk,
-                self.eos,
-            )
-            output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
-            output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
+        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lengths
+        mask_ins_out, _ = self.decoder.forward_mask_ins(
+            script_skip_tensor(output_tokens, can_ins_mask),
+            script_skip_tensor_list(encoder_out, can_ins_mask),
+        )
+        mask_ins_score = F.log_softmax(mask_ins_out, 2)
+        if eos_penalty > 0.0:
+            mask_ins_score[:, :, 0] -= eos_penalty
+        mask_ins_pred = mask_ins_score.max(-1)[1]
+        if max_ratio is not None and encoder_out[1] is not None:
+            mask_ins_pred = torch.min(
+                mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
+            )
+
+        output_tokens, output_scores = ins_placeholders(
+            output_tokens,
+            output_scores,
+            mask_ins_pred,
+            can_ins_mask,
+            self.pad,
+            self.unk,
+            self.eos,
+        )

-        # insert words
         can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
-        if can_ins_word.sum() != 0:
-            word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
-                _skip(output_tokens, can_ins_word), _skip(encoder_out, can_ins_word)
-            )
-            word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
-
-            _tokens, _scores = _apply_ins_words(
-                output_tokens[can_ins_word],
-                output_scores[can_ins_word],
-                word_ins_pred,
-                word_ins_score,
-                self.unk,
-            )
-            output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
-            output_scores = _fill(output_scores, can_ins_word, _scores, 0)
-            attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
+        word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
+            script_skip_tensor(output_tokens, can_ins_word),
+            script_skip_tensor_list(encoder_out, can_ins_word),
+        )
+        word_ins_score = F.log_softmax(word_ins_out, 2)
+        word_ins_pred = word_ins_score.max(-1)[1]
+
+        output_tokens, output_scores, attn = ins_words(
+            output_tokens,
+            output_scores,
+            attn,
+            word_ins_attn,
+            word_ins_pred,
+            word_ins_score,
+            can_ins_word,
+            self.pad,
+            self.unk,
+        )

         # delete some unnecessary paddings
         cut_off = output_tokens.ne(self.pad).sum(1).max()
-        output_tokens = output_tokens[:, :cut_off]
-        output_scores = output_scores[:, :cut_off]
-        attn = None if attn is None else attn[:, :cut_off, :]
-        return {
-            "output_tokens": output_tokens,
-            "output_scores": output_scores,
-            "attn": attn,
-        }
+
+        @torch.jit.script
+        def slice_wrap(x, l):
+            return x[:, :l]
+
+        @torch.jit.script
+        def slice_wrap_attn(x, l):
+            return x if x.size()[0] == 0 else x[:, :l, :]
+
+        output_tokens = slice_wrap(output_tokens, cut_off)
+        output_scores = slice_wrap(output_scores, cut_off)
+        # use the attn-aware wrapper so an empty attn tensor passes through
+        attn = slice_wrap_attn(attn, cut_off)
+        return [output_tokens, output_scores, attn, 0, 0]

     def initialize_output_tokens(self, encoder_out, src_tokens):
-        initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
-        initial_output_tokens[:, 0] = self.bos
-        initial_output_tokens[:, 1] = self.eos
-
-        initial_output_scores = initial_output_tokens.new_zeros(
-            *initial_output_tokens.size()
-        ).type_as(encoder_out["encoder_out"])
-
-        initial_attn = None
-        if getattr(self.decoder.layers[-1], "need_attn", False):
-            initial_attn = initial_output_tokens.new_zeros(
-                src_tokens.size(0), 2, src_tokens.size(1)
-            )
-        return {
-            "output_tokens": initial_output_tokens,
-            "output_scores": initial_output_scores,
-            "attn": initial_attn,
-        }
+        initial_output_tokens = torch.cat(
+            [
+                torch.zeros(src_tokens.size(0), 1).fill_(self.bos),
+                torch.zeros(src_tokens.size(0), 1).fill_(self.eos),
+            ],
+            1,
+        )
+
+        initial_output_scores = torch.zeros_like(initial_output_tokens).to(
+            encoder_out[0]
+        )
+
+        initial_attn = torch.empty([0])
+        if getattr(self.decoder.layers[-1], "need_attn", True):
+            initial_attn = torch.zeros([src_tokens.size(0), 2, src_tokens.size(1)]).to(
+                initial_output_tokens
+            )
+
+        return [initial_output_tokens, initial_output_scores, initial_attn, 0, 0]

-class LevenshteinTransformerDecoder(TransformerDecoder):
+class LevenshteinTransformerDecoder(TracingTransformerDecoder):
     def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         super().__init__(
             args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
@@ -467,25 +522,34 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
         self.embed_word_del = Embedding(2, self.output_embed_dim, None)

         # del_word, ins_mask, ins_word
-        self.early_exit = [int(i) for i in args.early_exit.split(',')]
+        self.early_exit = [int(i) for i in args.early_exit.split(",")]
         assert len(self.early_exit) == 3

         # copy layers for mask-predict/deletion
         self.layers_msk = None
         if getattr(args, "no_share_maskpredictor", False):
-            self.layers_msk = nn.ModuleList([
-                TransformerDecoderLayer(args, no_encoder_attn)
-                for _ in range(self.early_exit[1])
-            ])
+            self.layers_msk = nn.ModuleList(
+                [
+                    TransformerDecoderLayer(args, no_encoder_attn)
+                    for _ in range(self.early_exit[1])
+                ]
+            )
         self.layers_del = None
         if getattr(args, "no_share_discriminator", False):
-            self.layers_del = nn.ModuleList([
-                TransformerDecoderLayer(args, no_encoder_attn)
-                for _ in range(self.early_exit[0])
-            ])
+            self.layers_del = nn.ModuleList(
+                [
+                    TransformerDecoderLayer(args, no_encoder_attn)
+                    for _ in range(self.early_exit[0])
+                ]
+            )

     def extract_features(
-        self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
+        self,
+        prev_output_tokens,
+        encoder_out=None,
+        early_exit=None,
+        layers=None,
+        **unused
     ):
         """
         Similar to *forward* but only return features.
@@ -508,7 +572,7 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
         )

         # embed tokens and positions
-        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+        x = self.embed_scale * self.embed_tokens(prev_output_tokens.long())
         if self.project_in_dim is not None:
             x = self.project_in_dim(x)
@@ -525,13 +589,11 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
         decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
         layers = self.layers if layers is None else layers
         early_exit = len(layers) if early_exit is None else early_exit
-        for _, layer in enumerate(layers[: early_exit]):
+        for _, layer in enumerate(layers[:early_exit]):
             x, attn = layer(
                 x,
-                encoder_out["encoder_out"] if encoder_out is not None else None,
-                encoder_out["encoder_padding_mask"]
-                if encoder_out is not None
-                else None,
+                encoder_out[0] if encoder_out is not None else None,
+                encoder_out[1] if encoder_out is not None else None,
                 self_attn_mask=None,
                 self_attn_padding_mask=decoder_padding_mask,
             )
@@ -546,26 +608,38 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
         if self.project_out_dim is not None:
             x = self.project_out_dim(x)

-        return x, {"attn": attn, "inner_states": inner_states}
+        return x, attn, inner_states

     def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused):
-        features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
+        features, attn, _ = self.extract_features(
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            early_exit=self.early_exit[1],
+            layers=self.layers_msk,
+            **unused
         )
         features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
-        return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']
+        return F.linear(features_cat, self.embed_mask_ins.weight), attn

     def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused):
-        features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
+        features, attn, _ = self.extract_features(
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            early_exit=self.early_exit[2],
+            layers=self.layers,
+            **unused
         )
-        return self.output_layer(features), extra['attn']
+        return self.output_layer(features), attn

     def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused):
-        features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
+        features, attn, _ = self.extract_features(
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            early_exit=self.early_exit[0],
+            layers=self.layers_del,
+            **unused
        )
-        return F.linear(features, self.embed_word_del.weight), extra['attn']
+        return F.linear(features, self.embed_word_del.weight), attn


 @register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
@@ -595,7 +669,7 @@ def base_architecture(args):
     args.share_decoder_input_output_embed = getattr(
         args, "share_decoder_input_output_embed", False
     )
-    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+    args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
     args.no_token_positional_embeddings = getattr(
         args, "no_token_positional_embeddings", False
     )
...
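
Each forward_decoder call above applies the three Levenshtein edit operations in sequence: delete tokens, insert <unk> placeholders, then fill the placeholders with words. The placeholder-fill step reduces to a masked_scatter; a toy sketch with hypothetical indices (unk=3):

import torch

unk = 3
tokens = torch.tensor([[0, 3, 3, 5, 2]])     # bos, two placeholders, word, eos
word_pred = torch.tensor([[9, 8, 8, 9, 9]])  # decoder argmax at every position
masks = tokens.eq(unk)
filled = tokens.masked_scatter(masks, word_pred[masks])
# filled -> [[0, 8, 8, 5, 2]]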
@@ -3,7 +3,42 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Dict, List
+
 import torch
+from torch import Tensor
+
+
+@torch.jit.script
+def script_skip_tensor_list(x: List[Tensor], mask):
+    res = [xi[mask] if xi.size(0) == mask.size(0) else xi[:, mask] for xi in x]
+    outputs = []
+    for i, t in enumerate(res):
+        if t.numel() != 0:
+            outputs.append(t)
+        else:
+            outputs.append(x[i])
+    return outputs
+
+
+@torch.jit.script
+def script_skip_tensor(x: Tensor, mask):
+    # None case
+    if x.size(0) == 0:
+        return x
+    res = x[mask] if x.size(0) == mask.size(0) else x[:, mask]
+    if res.numel() == 0:
+        return x
+    else:
+        return res
+
+
+@torch.jit.script
+def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
+    outputs = {}
+    for s, t in x.items():
+        outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
+    return outputs


 def skip_tensors(x, mask):
@@ -31,7 +66,8 @@ def skip_tensors(x, mask):
     raise NotImplementedError


-def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
+@torch.jit.script
+def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
     """
     Expand 2D/3D tensor on dim=1
     """
@@ -46,18 +82,18 @@ def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
     dims = [x.size(0), trg_dim - x.size(1)]
     if x.dim() == 3:
         dims.append(x.size(2))

-    x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+    x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1)
     return x


-def fill_tensors(x, mask, y, padding_idx):
+@torch.jit.script
+def fill_tensors(x, mask, y, padding_idx: int):
     """
     Filling tensor x with y at masked positions (dim=0).
     """
-    if x is None:
-        return None
+    if x is None or x.size()[0] == 0:
+        return torch.empty([0])
     assert x.dim() == y.dim() and mask.size(0) == x.size(0)
     assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
@@ -72,7 +108,7 @@ def fill_tensors(x, mask, y, padding_idx):
         x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
         x[mask] = y
     elif x.size(1) > y.size(1):
-        x[mask] = padding_idx
+        x[mask] = torch.tensor(padding_idx)
         if x.dim() == 2:
             x[mask, :y.size(1)] = y
         else:
@@ -80,3 +116,88 @@ def fill_tensors(x, mask, y, padding_idx):
     else:
         x[mask] = y
     return x
+
+
+def _apply_ins_masks(
+    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
+):
+    in_masks = in_tokens.ne(padding_idx)
+    in_lengths = in_masks.sum(1)
+
+    # HACK: hacky way to shift all the paddings to eos first.
+    in_tokens.masked_fill_(~in_masks, eos_idx)
+    mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
+
+    out_lengths = in_lengths + mask_ins_pred.sum(1)
+    out_max_len = out_lengths.max()
+    out_masks = (
+        torch.arange(out_max_len, device=out_lengths.device)[None, :]
+        < out_lengths[:, None]
+    )
+
+    reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
+    out_tokens = (
+        in_tokens.new_zeros(in_tokens.size(0), out_max_len)
+        .fill_(padding_idx)
+        .masked_fill_(out_masks, unk_idx)
+    )
+    out_tokens[:, 0] = in_tokens[:, 0]
+    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
+
+    out_scores = None
+    if in_scores is not None:
+        in_scores.masked_fill_(~in_masks, 0)
+        out_scores = in_scores.new_zeros(*out_tokens.size())
+        out_scores[:, 0] = in_scores[:, 0]
+        out_scores.scatter_(1, reordering, in_scores[:, 1:])
+
+    return out_tokens, out_scores
+
+
+def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
+    word_ins_masks = in_tokens.eq(unk_idx)
+    out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
+
+    if in_scores is not None:
+        out_scores = in_scores.masked_scatter(
+            word_ins_masks, word_ins_scores[word_ins_masks]
+        )
+    else:
+        out_scores = None
+
+    return out_tokens, out_scores
+
+
+def _apply_del_words(
+    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
+):
+    # apply deletion to a tensor
+    in_masks = in_tokens.ne(padding_idx)
+    bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
+
+    max_len = in_tokens.size(1)
+    word_del_pred.masked_fill_(~in_masks, 1)
+    word_del_pred.masked_fill_(bos_eos_masks, 0)
+
+    reordering = (
+        torch.arange(max_len, device=in_tokens.device)[None, :]
+        .expand_as(in_tokens)
+        .contiguous()
+        .masked_fill_(word_del_pred, max_len)
+        .sort(1)[1]
+    )
+
+    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
+
+    out_scores = None
+    if in_scores is not None:
+        out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
+
+    out_attn = None
+    if in_attn is not None:
+        _mask = word_del_pred[:, :, None].expand_as(in_attn)
+        _reordering = reordering[:, :, None].expand_as(in_attn)
+        out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
+    return out_tokens, out_scores, out_attn
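
fill_tensors is the inverse of the skip helpers above: skip narrows the batch to the active rows before a decoder call, fill scatters the results back into the full batch. A quick sketch of the round trip when the widths match (hypothetical pad index 1):

import torch

x = torch.full((3, 4), 1, dtype=torch.long)   # full batch, all padding
mask = torch.tensor([True, False, True])      # two active rows
active = x[mask]                              # what script_skip_tensor yields
y = active + torch.arange(1, 9).view(2, 4)    # pretend-decoder output
x[mask] = y                                   # what fill_tensors does here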
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 import math
 from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
-from fairseq.models.levenshtein_transformer import _apply_del_words, _apply_ins_masks, _apply_ins_words
+from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words


 class BasicEnsembleModel(torch.nn.Module):
...