Commit 66d24dc2 authored by Jiatao Gu, committed by Facebook Github Bot

Enable separate models for insertion and deletion;

Summary:
The Diff contains two fixes:
(1) enabling non-shared decoder layers for deletion/insertion
(2) adding options to perform sampling instead of argmax when learning the deletion

Reviewed By: kahne

Differential Revision: D18011220

fbshipit-source-id: c60815fb7bc3a0004c81249504f7a641536ae2d8
parent a3c629b5
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from fairseq.utils import new_arange
 from fairseq.models import register_model, register_model_architecture
@@ -13,6 +14,7 @@ from fairseq.models.transformer import (
     TransformerDecoder,
     TransformerEncoder,
     TransformerModel,
+    TransformerDecoderLayer,
 )
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -24,7 +26,6 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
         import sys
         sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
         raise e
-
     in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
     with torch.cuda.device_of(in_tokens):
@@ -73,7 +74,6 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
         import sys
         sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
         raise e
-
     out_seq_len = out_tokens.size(1)
     with torch.cuda.device_of(in_tokens):
@@ -106,7 +106,6 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
         import sys
         sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
         raise e
-
     in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
     with torch.cuda.device_of(in_tokens):
@@ -247,7 +246,22 @@ class LevenshteinTransformerModel(TransformerModel):
             "--early-exit",
             default="6,6,6",
             type=str,
-            help="number of decoder layers before mask_ins, word_ins and word_del heads",
+            help="number of decoder layers for del_word, ins_mask, ins_word",
+        )
+        parser.add_argument(
+            "--no-share-discriminator",
+            action="store_true",
+            help="additional decoder layers to learn deletion",
+        )
+        parser.add_argument(
+            "--no-share-maskpredictor",
+            action="store_true",
+            help="additional decoder layers to learn predicting masks",
+        )
+        parser.add_argument(
+            "--sampling-for-deletion",
+            action='store_true',
+            help='instead of argmax, use sampling to predict the tokens'
         )

     @classmethod
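
The four options above are plain argparse flags, so their interplay is easy to check in isolation. A minimal standalone sketch (the bare parser is a hypothetical stand-in for the model's add_args hook; flag names and defaults are copied from the hunk above):

    import argparse

    parser = argparse.ArgumentParser()  # stand-in for the cls.add_args(parser) hook
    parser.add_argument("--early-exit", default="6,6,6", type=str)
    parser.add_argument("--no-share-discriminator", action="store_true")
    parser.add_argument("--no-share-maskpredictor", action="store_true")
    parser.add_argument("--sampling-for-deletion", action="store_true")

    args = parser.parse_args(["--no-share-discriminator", "--sampling-for-deletion"])
    assert args.no_share_discriminator and args.sampling_for_deletion
    assert not args.no_share_maskpredictor  # store_true flags default to False
    print([int(i) for i in args.early_exit.split(",")])  # -> [6, 6, 6]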
@@ -288,7 +302,13 @@ class LevenshteinTransformerModel(TransformerModel):
         )

         # make online prediction
-        word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
+        if self.decoder.sampling_for_deletion:
+            word_predictions = torch.multinomial(
+                F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
+                    word_ins_out.size(0), -1)
+        else:
+            word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
+
         word_predictions.masked_scatter_(
             ~masked_tgt_masks, tgt_tokens[~masked_tgt_masks]
         )
@@ -363,7 +383,7 @@ class LevenshteinTransformerModel(TransformerModel):
         )
         mask_ins_score = F.log_softmax(mask_ins_out, 2)
         if eos_penalty > 0.0:
-            mask_ins_score[:, :, 0] -= eos_penalty
+            mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
         mask_ins_pred = mask_ins_score.max(-1)[1]
         mask_ins_pred = torch.min(
             mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
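
The only change in this hunk is that the penalty is applied out of place instead of with an in-place -= on a slice (a plausible motivation is to avoid mutating the log_softmax output in place; the commit message does not say). Reading the surrounding code, index 0 appears to score "insert zero placeholders", so subtracting eos_penalty there biases decoding toward longer insertions. A sketch with dummy logits:

    import torch
    import torch.nn.functional as F

    mask_ins_out = torch.randn(2, 4, 3)  # (batch, length, num-insertions) logits
    eos_penalty = 1.5

    mask_ins_score = F.log_softmax(mask_ins_out, 2)
    # out-of-place update, mirroring the new line above
    mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
    mask_ins_pred = mask_ins_score.max(-1)[1]  # index 0 is now less likely to win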
@@ -442,15 +462,30 @@ class LevenshteinTransformerDecoder(TransformerDecoder):
         self.bos = dictionary.bos()
         self.unk = dictionary.unk()
         self.eos = dictionary.eos()
+        self.sampling_for_deletion = getattr(args, "sampling_for_deletion", False)
         self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None)
         self.embed_word_del = Embedding(2, self.output_embed_dim, None)

         # del_word, ins_mask, ins_word
         self.early_exit = [int(i) for i in args.early_exit.split(',')]
         assert len(self.early_exit) == 3

+        # copy layers for mask-predict/deletion
+        self.layers_msk = None
+        if getattr(args, "no_share_maskpredictor", False):
+            self.layers_msk = nn.ModuleList([
+                TransformerDecoderLayer(args, no_encoder_attn)
+                for _ in range(self.early_exit[1])
+            ])
+        self.layers_del = None
+        if getattr(args, "no_share_discriminator", False):
+            self.layers_del = nn.ModuleList([
+                TransformerDecoderLayer(args, no_encoder_attn)
+                for _ in range(self.early_exit[0])
+            ])
+
     def extract_features(
-        self, prev_output_tokens, encoder_out=None, early_exit=None, **unused
+        self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
     ):
         """
         Similar to *forward* but only return features.
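
The two blocks added to __init__ clone the decoder stack only when sharing is disabled; otherwise layers_msk / layers_del stay None and extract_features() falls back to the shared self.layers. A minimal sketch of the same share-vs-copy pattern, using torch's built-in decoder layer as a stand-in for fairseq's TransformerDecoderLayer:

    import torch.nn as nn

    def make_stack(n, d_model=8, nhead=2):
        return nn.ModuleList(nn.TransformerDecoderLayer(d_model, nhead) for _ in range(n))

    shared = make_stack(6)          # the main decoder, always present
    layers_del = None               # shared by default
    no_share_discriminator = True   # e.g. set via --no-share-discriminator
    if no_share_discriminator:
        layers_del = make_stack(6)  # deletion head gets its own parameters

    # dispatch mirrors extract_features(): fall back to the shared stack
    layers = shared if layers_del is None else layers_del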
@@ -488,12 +523,9 @@ class LevenshteinTransformerDecoder(TransformerDecoder):
         # decoder layers
         decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
-        for i, layer in enumerate(self.layers):
-
-            # early exit from the decoder.
-            if (early_exit is not None) and (i >= early_exit):
-                break
+        layers = self.layers if layers is None else layers
+        early_exit = len(layers) if early_exit is None else early_exit
+        for _, layer in enumerate(layers[: early_exit]):
             x, attn = layer(
                 x,
                 encoder_out["encoder_out"] if encoder_out is not None else None,
@@ -516,36 +548,25 @@ class LevenshteinTransformerDecoder(TransformerDecoder):
         return x, {"attn": attn, "inner_states": inner_states}

-    def forward_mask_ins(self, prev_output_tokens, encoder_out=None):
+    def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused):
         features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1]
+            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
         )
         features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
         return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']

-    def forward_word_ins(self, prev_output_tokens, encoder_out=None):
+    def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused):
         features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2]
+            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
         )
         return self.output_layer(features), extra['attn']

-    def forward_word_del(self, prev_output_tokens, encoder_out=None):
+    def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused):
         features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0]
+            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
         )
         return F.linear(features, self.embed_word_del.weight), extra['attn']

-    def forward_word_del_mask_ins(self, prev_output_tokens, encoder_out=None):
-        # merge the word-deletion and mask insertion into one operation,
-        assert self.early_exit[0] == self.early_exit[1], "must the same depth."
-        features, extra = self.extract_features(
-            prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2]
-        )
-        features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
-        f_word_del = F.linear(features, self.embed_word_del.weight)
-        f_mask_ins = F.linear(features_cat, self.embed_mask_ins.weight)
-        return f_word_del, f_mask_ins, extra['attn']

 @register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
 def base_architecture(args):
@@ -584,9 +605,11 @@ def base_architecture(args):
     args.decoder_output_dim = getattr(
         args, "decoder_output_dim", args.decoder_embed_dim
     )
+    args.sampling_for_deletion = getattr(args, "sampling_for_deletion", False)
     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
-    args.early_exit = getattr(args, "early_exit", "(6, 6, 6)")
+    args.early_exit = getattr(args, "early_exit", "6,6,6")
+    args.no_share_discriminator = getattr(args, "no_share_discriminator", False)
+    args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False)
 @register_model_architecture(
......
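
base_architecture() registers its defaults with the usual fairseq getattr pattern: a value already present on args (set by a CLI flag or a more specific architecture) wins, otherwise the default is filled in. A tiny sketch with a bare Namespace standing in for the parsed args:

    from argparse import Namespace

    args = Namespace(early_exit="12,1,1")  # pre-set value survives
    args.early_exit = getattr(args, "early_exit", "6,6,6")
    args.no_share_discriminator = getattr(args, "no_share_discriminator", False)

    print(args.early_exit)              # -> 12,1,1 (kept)
    print(args.no_share_discriminator)  # -> False (default filled in)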