Commit 1c667929 authored by Sarthak Garg, committed by Facebook Github Bot

Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" (#877)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/877

This PR implements the guided alignment training described in "Jointly Learning to Align and Translate with Transformer Models" (https://arxiv.org/abs/1909.02074).

In summary, it allows training selected heads of the Transformer model with external alignments computed by statistical alignment toolkits. During inference, the attention probabilities from the trained heads can be used to extract reliable alignments. In our work, we did not see any regression in translation performance from guided alignment training.
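For reference, a rough sketch of the resulting objective (notation ours, not from the diff): the usual label-smoothed cross-entropy translation loss is combined with an alignment loss computed from the supervised attention head.

```latex
% Hedged sketch; see label_smoothed_cross_entropy_with_alignment below.
% L_t : label-smoothed cross-entropy translation loss
% p_i(j) : attention probability of target position i on source position j
% A : set of (source j, target i) pairs from the external alignment
% w_i : inverse frequency of target index i (see compute_alignment_weights)
\mathcal{L} = \mathcal{L}_t + \lambda\,\mathcal{L}_a,
\qquad
\mathcal{L}_a = -\sum_{(j,\,i)\,\in\,A} w_i \log p_i(j)
```

where lambda is `--alignment-lambda` (default 0.05).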
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1095

Differential Revision: D17170337

Pulled By: myleott

fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859
parent acb6fba0
...@@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
...@@ -100,6 +101,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
We also have more detailed READMEs to reproduce results from specific papers:
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
...
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)
This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).
## Training a joint alignment-translation model on WMT'18 En-De
##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```
##### 2. Generate alignments from a statistical alignment toolkit, e.g. Giza++ or FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```
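Optionally, forward and reverse alignments can be symmetrized before training; the scripts linked in step 6 cover this, and a rough sketch using fast_align's bundled `atools` (file names below follow the convention above and are otherwise our assumption) could look like:
```bash
# Sketch only: produce en->de and de->en alignments, then symmetrize.
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.fwd.align
$ALIGN -i bpe.32k/train.en-de -d -o -v -r > bpe.32k/train.rev.align
fast_align/build/atools -i bpe.32k/train.fwd.align -j bpe.32k/train.rev.align \
    -c grow-diag-final-and > bpe.32k/train.align
```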
##### 3. Preprocess the dataset with the above generated alignments.
```bash
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref bpe.32k/train \
--validpref bpe.32k/valid \
--testpref bpe.32k/test \
--align-suffix align \
--destdir binarized/ \
--joined-dictionary \
--workers 32
```
##### 4. Train a model
```bash
fairseq-train \
binarized \
--arch transformer_wmt_en_de_big_align --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
--lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 3500 --label-smoothing 0.1 \
--save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
--keep-interval-updates -1 --save-interval-updates 0 \
--load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
--fp16
```
Note that the `--fp16` flag requires a Volta GPU or newer and CUDA 9.1 or greater.
If you want to train the above model with big batches (assuming your machine has 8 GPUs; a full command sketch follows this list):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches
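A sketch of the big-batch invocation, identical to the command in step 4 apart from the learning rate and the added `--update-freq`:
```bash
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --update-freq 8 --fp16
```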
##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
binarized --gen-subset test --print-alignment \
--source-lang en --target-lang de \
--path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```
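The alignments appear on the `A-<id>` lines of the output, tab-separated from the id, as space-separated src-tgt index pairs. A rough sketch for collecting them into one file, assuming the command above was redirected to `gen.out`:
```bash
# Sketch only: keep the alignment lines and drop the "A-<id>" prefix.
# Note that lines may need to be reordered by id before scoring.
grep "^A-" gen.out | cut -f2 > test.bpe.align
```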
##### 6. Other resources.
The code for:
1. preparing alignment test sets
2. converting BPE-level alignments to token-level alignments
3. symmetrizing bidirectional alignments
4. evaluating alignments using the AER metric
can be found [here](https://github.com/lilt/alignment-scripts).
## Citation
```bibtex
@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
```
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training-parallel-nc-v13/news-commentary-v13.de-en"
"rapid2016.de-en"
)
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit
fi
src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k
mkdir -p $orig $tmp $prep $bpe
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
url=${URLS[i]}
file=$(basename $url)
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
rm -rf $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
done
done
echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
echo ""
done
# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100
# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
rm -rf $tmp/valid.$l
cat $orig/$dev.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done
mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output
#BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
...@@ -52,6 +52,22 @@ class Binarizer:
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
nseq = 0
with open(filename, 'r') as f:
f.seek(offset)
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = alignment_parser(line)
nseq += 1
consumer(ids)
line = f.readline()
return {'nseq': nseq}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import utils
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion
@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
def __init__(self, args, task):
super().__init__(args, task)
self.alignment_lambda = args.alignment_lambda
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
super(LabelSmoothedCrossEntropyCriterionWithAlignment,
LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
help='weight for the alignment loss')
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
alignment_loss = None
# Compute alignment loss only for the training set and non-dummy batches.
if 'alignments' in sample and sample['alignments'] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
logging_output['alignment_loss'] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]['attn']
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
align = sample['alignments']
align_weights = sample['align_weights'].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
else:
return None
return loss
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
...@@ -22,6 +22,28 @@ def collate(
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
print("| alignment size mismatch found, skipping alignment!")
return False
return True
def compute_alignment_weights(alignments):
"""
Given a tensor of shape [:, 2] containing the source-target indices
corresponding to the alignments, a weight vector containing the
inverse frequency of each target index is computed.
For example, if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
index 3 is repeated twice)
"""
align_tgt = alignments[:, 1]
_, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
return 1. / align_weights.float()
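As a quick standalone check (ours, not part of the diff), the weighting above can be reproduced with plain torch:
```python
import torch

# Toy check of the inverse-frequency weights from the docstring example.
alignments = torch.LongTensor([[5, 7], [2, 3], [1, 3], [4, 2]])
align_tgt = alignments[:, 1]  # target indices: [7, 3, 3, 2]
_, inverse, counts = torch.unique(align_tgt, return_inverse=True, return_counts=True)
weights = 1. / counts[inverse].float()
print(weights)  # tensor([1.0000, 0.5000, 0.5000, 1.0000])
```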
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
...@@ -35,6 +57,7 @@ def collate(
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
target = target.index_select(0, sort_order)
tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
if input_feeding:
...@@ -61,6 +84,32 @@ def collate(
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
if samples[0].get('alignment', None) is not None:
bsz, tgt_sz = batch['target'].shape
src_sz = batch['net_input']['src_tokens'].shape[1]
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
if left_pad_source:
offsets[:, 0] += (src_sz - src_lengths)
if left_pad_target:
offsets[:, 1] += (tgt_sz - tgt_lengths)
alignments = [
alignment + offset
for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
if check_alignment(alignment, src_len, tgt_len)
]
if len(alignments) > 0:
alignments = torch.cat(alignments, dim=0)
align_weights = compute_alignment_weights(alignments)
batch['alignments'] = alignments
batch['align_weights'] = align_weights
return batch
...@@ -91,6 +140,8 @@ class LanguagePairDataset(FairseqDataset):
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent (default: False).
align_dataset (torch.utils.data.Dataset, optional): dataset
containing alignments.
"""
def __init__(
...@@ -98,7 +149,9 @@ class LanguagePairDataset(FairseqDataset):
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
...@@ -118,6 +171,9 @@ class LanguagePairDataset(FairseqDataset):
self.input_feeding = input_feeding
self.remove_eos_from_source = remove_eos_from_source
self.append_eos_to_target = append_eos_to_target
self.align_dataset = align_dataset
if self.align_dataset is not None:
assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
...@@ -136,11 +192,14 @@ class LanguagePairDataset(FairseqDataset):
if self.src[index][-1] == eos:
src_item = self.src[index][:-1]
example = {
'id': index,
'source': src_item,
'target': tgt_item,
}
if self.align_dataset is not None:
example['alignment'] = self.align_dataset[index]
return example
def __len__(self):
return len(self.src)
...@@ -212,3 +271,5 @@ class LanguagePairDataset(FairseqDataset):
self.src.prefetch(indices)
if self.tgt is not None:
self.tgt.prefetch(indices)
if self.align_dataset is not None:
self.align_dataset.prefetch(indices)
...@@ -222,6 +222,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel):
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
def forward_decoder(self, prev_output_tokens, **kwargs):
return self.decoder(prev_output_tokens, **kwargs)
def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
Similar to *forward* but only return features.
...
...@@ -68,6 +68,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.supports_align_args = True
@staticmethod
def add_args(parser):
...@@ -195,6 +196,69 @@ class TransformerModel(FairseqEncoderDecoderModel):
)
@register_model('transformer_align')
class TransformerAlignModel(TransformerModel):
"""
See "Jointly Learning to Align and Translate with Transformer
Models" (Garg et al., EMNLP 2019).
"""
def __init__(self, encoder, decoder, args):
super().__init__(encoder, decoder)
self.alignment_heads = args.alignment_heads
self.alignment_layer = args.alignment_layer
self.full_context_alignment = args.full_context_alignment
@staticmethod
def add_args(parser):
# fmt: off
super(TransformerAlignModel, TransformerAlignModel).add_args(parser)
parser.add_argument('--alignment-heads', type=int, metavar='D',
help='Number of cross attention heads per layer to supervise with alignments')
parser.add_argument('--alignment-layer', type=int, metavar='D',
help='Layer number which has to be supervised. 0 corresponds to the bottom-most layer.')
parser.add_argument('--full-context-alignment', type=bool, metavar='D',
help='Whether or not alignment is supervised conditioned on the full target context.')
# fmt: on
@classmethod
def build_model(cls, args, task):
# set any default arguments
transformer_align(args)
transformer_model = TransformerModel.build_model(args, task)
return TransformerAlignModel(transformer_model.encoder, transformer_model.decoder, args)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
return self.forward_decoder(prev_output_tokens, encoder_out)
def forward_decoder(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
attn_args = {'alignment_layer': self.alignment_layer, 'alignment_heads': self.alignment_heads}
decoder_out = self.decoder(
prev_output_tokens,
encoder_out,
**attn_args,
**extra_args,
)
if self.full_context_alignment:
attn_args['full_context_alignment'] = self.full_context_alignment
_, alignment_out = self.decoder(
prev_output_tokens, encoder_out, features_only=True, **attn_args, **extra_args,
)
decoder_out[1]['attn'] = alignment_out['attn']
return decoder_out
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
...@@ -423,7 +487,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
self.layer_norm = None
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
...@@ -432,25 +503,53 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
full_context_alignment=False,
alignment_layer=None,
alignment_heads=None,
**unused,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = len(self.layers) - 1
# embed positions
positions = self.embed_positions(
prev_output_tokens,
...@@ -474,15 +573,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
attn = None
inner_states = [x]
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
...@@ -491,15 +589,32 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
encoder_state = encoder_out['encoder_out']
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm:
x = self.layer_norm(x)
...@@ -531,7 +646,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
...@@ -668,3 +788,18 @@ def transformer_wmt_en_de_big_t2t(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('transformer_align', 'transformer_align')
def transformer_align(args):
args.alignment_heads = getattr(args, 'alignment_heads', 1)
args.alignment_layer = getattr(args, 'alignment_layer', 4)
args.full_context_alignment = getattr(args, 'full_context_alignment', False)
base_architecture(args)
@register_model_architecture('transformer_align', 'transformer_wmt_en_de_big_align')
def transformer_wmt_en_de_big_align(args):
args.alignment_heads = getattr(args, 'alignment_heads', 1)
args.alignment_layer = getattr(args, 'alignment_layer', 4)
transformer_wmt_en_de_big(args)
...@@ -90,15 +90,37 @@ class MultiheadAttention(nn.Module):
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query, key, value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
):
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
...@@ -249,12 +271,11 @@ class MultiheadAttention(nn.Module):
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if (self.onnx_trace and attn.size(1) == 1):
# when ONNX tracing a single decoder step (sequence length == 1)
...@@ -265,9 +286,10 @@ class MultiheadAttention(nn.Module):
attn = self.out_proj(attn)
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
else:
attn_weights = None
...
...@@ -195,16 +195,25 @@ class TransformerDecoderLayer(nn.Module):
prev_attn_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
need_attn=False,
need_head_weights=False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
...@@ -259,7 +268,8 @@ class TransformerDecoderLayer(nn.Module):
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
...
...@@ -224,6 +224,8 @@ def add_preprocess_args(parser):
help="comma separated, valid file prefixes")
group.add_argument("--testpref", metavar="FP", default=None,
help="comma separated, test file prefixes")
group.add_argument("--align-suffix", metavar="FP", default=None,
help="alignment file suffix")
group.add_argument("--destdir", metavar="DIR", default="data-bin",
help="destination dir")
group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
...
...@@ -7,7 +7,8 @@ import math
import torch
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
...@@ -81,7 +82,6 @@ class SequenceGenerator(object):
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
...@@ -98,14 +98,7 @@ class SequenceGenerator(object):
self.search = search.BeamSearch(tgt_dict)
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Generate a batch of translations.
Args:
...@@ -113,8 +106,21 @@ class SequenceGenerator(object):
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
model = EnsembleModel(models)
return self._generate(model, sample, **kwargs)
@torch.no_grad()
def _generate(
self,
model,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
if not self.retain_dropout:
model.eval()
...@@ -155,7 +161,6 @@ class SequenceGenerator(object):
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn, attn_buf = None, None
# The blacklist indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
...@@ -251,17 +256,15 @@ class SequenceGenerator(object):
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i]
else:
hypo_attn = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': None,
'positional_scores': pos_scores[i],
}
...@@ -345,7 +348,6 @@ class SequenceGenerator(object):
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
...@@ -512,7 +514,6 @@ class SequenceGenerator(object):
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
...@@ -577,9 +578,11 @@ class EnsembleModel(torch.nn.Module):
temperature=1.,
):
if self.incremental_states is not None:
decoder_out = list(model.forward_decoder(
tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model],
))
else:
decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
if temperature != 1.:
decoder_out[0].div_(temperature)
...@@ -605,3 +608,104 @@ class EnsembleModel(torch.nn.Module):
return
for model in self.models:
model.decoder.reorder_incremental_state(self.incremental_states[model], new_order)
class SequenceGeneratorWithAlignment(SequenceGenerator):
def __init__(self, tgt_dict, left_pad_target=False, **kwargs):
"""Generates translations of a given source sentence.
Produces alignments following "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
left_pad_target (bool, optional): whether the hypotheses
should be left-padded when they are teacher-forced for
generating alignments.
"""
super().__init__(tgt_dict, **kwargs)
self.left_pad_target = left_pad_target
@torch.no_grad()
def generate(self, models, sample, **kwargs):
model = EnsembleModelWithAlignment(models)
finalized = super()._generate(model, sample, **kwargs)
src_tokens = sample['net_input']['src_tokens']
bsz = src_tokens.shape[0]
beam_size = self.beam_size
src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \
self._prepare_batch_for_alignment(sample, finalized)
if any(getattr(m, 'full_context_alignment', False) for m in model.models):
attn = model.forward_align(src_tokens, src_lengths, prev_output_tokens)
else:
attn = [
finalized[i // beam_size][i % beam_size]['attention'].transpose(1, 0)
for i in range(bsz * beam_size)
]
# Process the attn matrix to extract hard alignments.
for i in range(bsz * beam_size):
alignment = utils.extract_hard_alignment(attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos)
finalized[i // beam_size][i % beam_size]['alignment'] = alignment
return finalized
def _prepare_batch_for_alignment(self, sample, hypothesis):
src_tokens = sample['net_input']['src_tokens']
bsz = src_tokens.shape[0]
src_tokens = src_tokens[:, None, :].expand(-1, self.beam_size, -1).contiguous().view(bsz * self.beam_size, -1)
src_lengths = sample['net_input']['src_lengths']
src_lengths = src_lengths[:, None].expand(-1, self.beam_size).contiguous().view(bsz * self.beam_size)
prev_output_tokens = data_utils.collate_tokens(
[beam['tokens'] for example in hypothesis for beam in example],
self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True,
)
tgt_tokens = data_utils.collate_tokens(
[beam['tokens'] for example in hypothesis for beam in example],
self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False,
)
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
class EnsembleModelWithAlignment(EnsembleModel):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__(models)
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
avg_attn = None
for model in self.models:
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
attn = decoder_out[1]['attn']
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
if len(self.models) > 1:
avg_attn.div_(len(self.models))
return avg_attn
def _decode_one(
self, tokens, model, encoder_out, incremental_states, log_probs,
temperature=1.,
):
if self.incremental_states is not None:
decoder_out = list(model.forward_decoder(
tokens,
encoder_out=encoder_out,
incremental_state=self.incremental_states[model],
))
else:
decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
if temperature != 1.:
decoder_out[0].div_(temperature)
attn = decoder_out[1]
if type(attn) is dict:
attn = attn.get('attn', None)
if attn is not None:
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
probs = probs[:, -1, :]
return probs, attn
...@@ -14,6 +14,7 @@ class SequenceScorer(object):
def __init__(self, tgt_dict, softmax_batch=None):
self.pad = tgt_dict.pad()
self.eos = tgt_dict.eos()
self.softmax_batch = softmax_batch or sys.maxsize
assert self.softmax_batch > 0
...@@ -44,6 +45,7 @@ class SequenceScorer(object):
)
return probs
orig_target = sample['target']
# compute scores for each model in the ensemble
...@@ -53,6 +55,8 @@ class SequenceScorer(object):
model.eval()
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
if type(attn) is dict:
attn = attn.get('attn', None)
batched = batch_for_softmax(decoder_out, orig_target)
probs, idx = None, 0
...@@ -100,8 +104,9 @@ class SequenceScorer(object):
avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i]
alignment = utils.extract_hard_alignment(avg_attn_i, sample['net_input']['src_tokens'][i],
sample['target'][i], self.pad, self.eos)
else:
avg_attn_i = alignment = None
hypos.append([{
...
...@@ -198,8 +198,12 @@ class FairseqTask(object):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(self.target_dictionary)
else:
from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment
if getattr(args, 'print_alignment', False):
seq_gen_cls = SequenceGeneratorWithAlignment
else:
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
self.target_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),
...
...@@ -24,7 +24,7 @@ def load_langpair_dataset(
tgt, tgt_dict,
combine, dataset_impl, upsample_primary,
left_pad_source, left_pad_target, max_source_positions,
max_target_positions, prepend_bos=False, load_alignments=False,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
...@@ -74,6 +74,12 @@ def load_langpair_dataset(
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)
return LanguagePairDataset(
src_dataset, src_dataset.sizes, src_dict,
tgt_dataset, tgt_dataset.sizes, tgt_dict,
...@@ -81,6 +87,7 @@ def load_langpair_dataset(
left_pad_target=left_pad_target,
max_source_positions=max_source_positions,
max_target_positions=max_target_positions,
align_dataset=align_dataset,
)
...@@ -120,6 +127,8 @@ class TranslationTask(FairseqTask):
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--load-alignments', action='store_true',
help='load the binarized alignments')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
...@@ -193,6 +202,7 @@ class TranslationTask(FairseqTask):
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
...
...@@ -16,6 +16,7 @@ import warnings
import torch
import torch.nn.functional as F
from itertools import accumulate
from fairseq.modules import gelu, gelu_accurate
...@@ -367,3 +368,47 @@ def set_torch_seed(seed):
assert isinstance(seed, int)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def parse_alignment(line):
"""
Parses a single line from the alignment file.
Args:
line (str): String containing the alignment of the format:
<src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
<src_idx_m>-<tgt_idx_m>. All indices are 0-indexed.
Returns:
torch.IntTensor: packed alignments of shape (2 * m).
"""
alignments = line.strip().split()
parsed_alignment = torch.IntTensor(2 * len(alignments))
for idx, alignment in enumerate(alignments):
src_idx, tgt_idx = alignment.split('-')
parsed_alignment[2 * idx] = int(src_idx)
parsed_alignment[2 * idx + 1] = int(tgt_idx)
return parsed_alignment
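For illustration (assuming a fairseq checkout with this commit installed), the packed layout interleaves source and target indices:
```python
from fairseq import utils

# "0-0 1-2 2-2" -> [src_0, tgt_0, src_1, tgt_1, src_2, tgt_2]
print(utils.parse_alignment("0-0 1-2 2-2"))
# tensor([0, 0, 1, 2, 2, 2], dtype=torch.int32)
```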
def get_token_to_word_mapping(tokens, exclude_list):
n = len(tokens)
word_start = [int(token not in exclude_list) for token in tokens]
word_idx = list(accumulate(word_start))
token_to_word = {i: word_idx[i] for i in range(n)}
return token_to_word
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1)
src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1)
src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
alignment = []
if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
attn_valid = attn[tgt_valid]
attn_valid[:, src_invalid] = float('-inf')
_, src_indices = attn_valid.max(dim=1)
for tgt_idx, src_idx in zip(tgt_valid, src_indices):
alignment.append((src_token_to_word[src_idx.item()] - 1, tgt_token_to_word[tgt_idx.item()] - 1))
return alignment
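A toy illustration of the hard-alignment extraction (ours, not from the diff), assuming pad=1, eos=2, and a fairseq checkout with this commit installed:
```python
import torch
from fairseq import utils

# Two real source tokens plus eos, two real target tokens plus eos;
# attn has shape (tgt_len, src_len).
attn = torch.tensor([[0.9, 0.1, 0.0],
                     [0.2, 0.7, 0.1],
                     [0.1, 0.1, 0.8]])
src = torch.tensor([4, 5, 2])
tgt = torch.tensor([6, 7, 2])
print(utils.extract_hard_alignment(attn, src, tgt, pad=1, eos=2))
# [(0, 0), (1, 1)]  -- (source word, target word) pairs, 0-indexed
```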
...@@ -137,7 +137,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
...@@ -156,7 +156,7 @@ def main(args):
if args.print_alignment:
print('A-{}\t{}'.format(
sample_id,
' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
))
if args.print_step:
...@@ -180,6 +180,7 @@ def main(args):
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
return scorer
...
...@@ -162,7 +162,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
...@@ -174,9 +174,10 @@ def main(args):
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if args.print_alignment:
alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
print('A-{}\t{}'.format(
id,
alignment_str
))
# update running id counter
...
...@@ -157,6 +157,60 @@ def main(args):
            )
        )

    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result['nseq']

        input_file = input_prefix
        # Split the alignment file into one chunk per worker; chunks 1..N-1 are
        # binarized in separate processes, chunk 0 in the main process below.
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1]
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file, utils.parse_alignment, lambda t: ds.add_item(t),
                offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            # Append the shards written by the workers and clean up their temp files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        print(
            "| [alignments] {}: parsed {} alignments".format(
                input_file,
                nseq[0]
            )
        )
    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
...@@ -180,9 +234,19 @@ def main(args):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    def make_all_alignments():
        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers)
        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers)
        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    if args.align_suffix:
        make_all_alignments()
    print("| Wrote preprocessed data to {}".format(args.destdir))
...@@ -242,11 +306,28 @@ def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos
    return res


def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
                                      impl=args.dataset_impl, vocab_size=None)

    def consumer(tensor):
        ds.add_item(tensor)

    res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset,
                                        end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
    return res
def dataset_dest_prefix(args, output_prefix, lang):
    base = "{}/{}".format(args.destdir, output_prefix)
-   lang_part = (
-       ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else ""
-   )
+   if lang is not None:
+       lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
+   elif args.only_source:
+       lang_part = ""
+   else:
+       lang_part = ".{}-{}".format(args.source_lang, args.target_lang)
    return "{}{}".format(base, lang_part)
...
...@@ -266,6 +266,27 @@ class TestTranslation(unittest.TestCase):
                    '--gen-expert', '0'
                ])
    def test_alignment(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_alignment') as data_dir:
                create_dummy_data(data_dir, alignment=True)
                preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
                train_translation_model(
                    data_dir,
                    'transformer_align',
                    [
                        '--encoder-layers', '2',
                        '--decoder-layers', '2',
                        '--encoder-embed-dim', '8',
                        '--decoder-embed-dim', '8',
                        '--load-alignments',
                        '--alignment-layer', '1',
                        '--criterion', 'label_smoothed_cross_entropy_with_alignment'
                    ],
                    run_validation=True,
                )
                generate_main(data_dir)
class TestStories(unittest.TestCase):
...@@ -484,7 +505,7 @@ class TestCommonOptions(unittest.TestCase):
                generate_main(data_dir)

-def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
+def create_dummy_data(data_dir, num_examples=1000, maxlen=20, alignment=False):
    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)
...@@ -497,6 +518,20 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
            print(ex_str, file=h)
            offset += ex_len
    def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), 'r') as src_f, \
                open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \
                open(os.path.join(data_dir, filename), 'w') as h:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)])
                print(ex_str, file=h)
    _create_dummy_data('train.in')
    _create_dummy_data('train.out')
    _create_dummy_data('valid.in')
...@@ -504,6 +539,10 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
    _create_dummy_data('test.in')
    _create_dummy_data('test.out')

    if alignment:
        _create_dummy_alignment_data('train.in', 'train.out', 'train.align')
        _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align')
        _create_dummy_alignment_data('test.in', 'test.out', 'test.align')
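The dummy alignment files use the same whitespace-separated `i-j` format that `utils.parse_alignment` expects; a made-up line and how it reads:

```python
# One line of train.align: 0-based source-target token index pairs for one sentence pair.
line = "0-0 1-2 1-3 4-1"
pairs = [tuple(map(int, p.split('-'))) for p in line.split()]
print(pairs)  # [(0, 0), (1, 2), (1, 3), (4, 1)]
```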
def preprocess_translation_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()
...