Commit 891f2307 authored by Mehdi Drissi

Merge branch 'master' of https://github.com/pytorch/fairseq into minor_fixes

parents 762956a5 b458977a
@@ -34,13 +34,17 @@ def main(args):
    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
+        if args.fp16:
+            model.half()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
+        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences or 4,
        max_positions=model.max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
+        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
...
*/*
!*/*.sh
!*/*.md
-Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit
+# Example usage for Neural Machine Translation

-These scripts provide an example of pre-processing data for the NMT task.
+These scripts provide an example of pre-processing data for the NMT task
+and instructions for how to replicate the results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).

-# prepare-iwslt14.sh
+## Preprocessing
+
+### prepare-iwslt14.sh

Provides an example of pre-processing for IWSLT'14 German to English translation task: ["Report on the 11th IWSLT evaluation campaign" by Cettolo et al.](http://workshop2014.iwslt.org/downloads/proceeding.pdf)
@@ -34,7 +37,7 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
```

-# prepare-wmt14en2de.sh
+### prepare-wmt14en2de.sh

Provides an example of pre-processing for the WMT'14 English to German translation task. By default it will produce a dataset that was modeled after ["Attention Is All You Need" by Vaswani et al.](https://arxiv.org/abs/1706.03762) that includes news-commentary-v12 data.
@@ -52,7 +55,7 @@ $ bash prepare-wmt14en2de.sh
$ cd ../..

# Binarize the dataset:
-$ TEXT=data/wmt14_en_de
+$ TEXT=examples/translation/wmt14_en_de
$ python preprocess.py --source-lang en --target-lang de \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/wmt14_en_de --thresholdtgt 0 --thresholdsrc 0
@@ -72,7 +75,7 @@ $ python generate.py data-bin/wmt14_en_de \
```

-# prepare-wmt14en2fr.sh
+### prepare-wmt14en2fr.sh

Provides an example of pre-processing for the WMT'14 English to French translation task.
@@ -84,7 +87,7 @@ $ bash prepare-wmt14en2fr.sh
$ cd ../..

# Binarize the dataset:
-$ TEXT=data/wmt14_en_fr
+$ TEXT=examples/translation/wmt14_en_fr
$ python preprocess.py --source-lang en --target-lang fr \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0
@@ -103,3 +106,39 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
  --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```
+
+## Replicating results from "Scaling Neural Machine Translation"
+
+To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187):
+
+1. Prepare the WMT'14 En-De data with a BPE vocab of 32k:
+```
+$ BPE_TOKENS=32764 bash prepare-wmt14en2de.sh
+$ cd ../..
+```
+
+2. Preprocess the dataset with a joined dictionary:
+```
+$ TEXT=examples/translation/wmt14_en_de
+$ python preprocess.py --source-lang en --target-lang de \
+  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+  --destdir data-bin/wmt14_en_de_joined_dict \
+  --nwordssrc 32768 --nwordstgt 32768 \
+  --joined-dictionary
+```
+
+3. Train a model:
+```
+$ python train.py data-bin/wmt14_en_de_joined_dict \
+  --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
+  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
+  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
+  --lr 0.0005 --min-lr 1e-09 \
+  --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+  --max-tokens 3584 \
+  --fp16
+```
+
+Note that the `--fp16` flag requires that you have CUDA 9.1 or greater and a Volta GPU.
+
+If you want to train the above model with big batches (assuming your machine has 8 GPUs):
+- add `--update-freq 16` to simulate training on 8*16=128 GPUs
+- increase the learning rate; 0.001 works well for big batches
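For illustration, applying both adjustments to the training command from step 3 would give roughly the following (a sketch, not part of this commit; every flag other than `--update-freq 16` and `--lr 0.001` is copied unchanged from the step 3 command above):
```
# Big-batch variant of the step 3 command: gradients are accumulated over 16
# batches per GPU (--update-freq 16) and the learning rate is raised to 0.001.
$ python train.py data-bin/wmt14_en_de_joined_dict \
  --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
  --lr 0.001 --min-lr 1e-09 \
  --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 3584 --update-freq 16 \
  --fp16
```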
@@ -13,7 +13,7 @@ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt
-BPE_TOKENS=40000
+BPE_TOKENS="${BPE_TOKENS:-40000}"

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
...
@@ -72,7 +72,7 @@ class MonolingualDataset(FairseqDataset):
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
-        order.append(self.sizes)
+        order.append(np.flip(self.sizes, 0))
        return np.lexsort(order)

    def valid_size(self, index, max_positions):
...
@@ -130,6 +130,12 @@ class FP16Trainer(Trainer):
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
+            if self.scaler.loss_scale <= self.args.min_loss_scale:
+                raise Exception((
+                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
+                    'Try lowering the learning rate, using gradient clipping or '
+                    'increasing the batch size.'
+                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
        return grad_norm
...
@@ -26,6 +26,12 @@ class CompositeEncoder(FairseqEncoder):
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        for key in self.encoders:
+            encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
+        return encoder_out
+
    def max_positions(self):
        return min([self.encoders[key].max_positions() for key in self.encoders])
...
@@ -18,6 +18,10 @@ class FairseqEncoder(nn.Module):
    def forward(self, src_tokens, src_lengths):
        raise NotImplementedError

+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        raise NotImplementedError
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        raise NotImplementedError
...
@@ -32,9 +32,6 @@ class FairseqIncrementalDecoder(FairseqDecoder):
        )
        self.apply(apply_reorder_incremental_state)

-    def reorder_encoder_out(self, encoder_out, new_order):
-        return encoder_out
-
    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) != beam_size:
...
@@ -268,6 +268,17 @@ class FConvEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = (
+                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
+                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
+            )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -496,12 +507,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
        encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
        utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')
...
@@ -226,6 +226,19 @@ class FConvEncoder(FairseqEncoder):
            'encoder_out': (x, y),
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
+        )
+        if 'pretrained' in encoder_out_dict:
+            encoder_out_dict['pretrained']['encoder_out'] = tuple(
+                eo.index_select(0, new_order)
+                for eo in encoder_out_dict['pretrained']['encoder_out']
+            )
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -409,30 +422,12 @@ class FConvDecoder(FairseqDecoder):
        else:
            return x, avg_attn_scores

-    def reorder_incremental_state(self, incremental_state, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(incremental_state, new_order)
-
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder']['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out']
-        )
-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
-                eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
-            )
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def _split_encoder_out(self, encoder_out):
-        """Split and transpose encoder outputs.
-        """
+        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
...
@@ -197,6 +197,16 @@ class LSTMEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(1, new_order)
+            for eo in encoder_out_dict['encoder_out']
+        )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
@@ -366,16 +376,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
-        )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number
...
@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

+from fairseq import utils
+
from fairseq.modules import (
    LearnedPositionalEmbedding, MultiheadAttention,
    SinusoidalPositionalEmbedding,
@@ -36,6 +38,8 @@ class TransformerModel(FairseqModel):
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
+                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
@@ -48,6 +52,8 @@ class TransformerModel(FairseqModel):
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
                            help='use learned positional embeddings in the encoder')
+        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
+                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
@@ -69,12 +75,20 @@ class TransformerModel(FairseqModel):
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
+
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

-        def build_embedding(dictionary, embed_dim):
+        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
-            return Embedding(num_embeddings, embed_dim, padding_idx)
+            emb = Embedding(num_embeddings, embed_dim, padding_idx)
+            # if provided, load from preloaded dictionaries
+            if path:
+                embed_dict = utils.parse_embedding(path)
+                utils.load_embedding(embed_dict, dictionary, emb)
+            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
@@ -82,12 +96,21 @@ class TransformerModel(FairseqModel):
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
-            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
+            if args.decoder_embed_path and (
+                    args.decoder_embed_path != args.encoder_embed_path):
+                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
+            encoder_embed_tokens = build_embedding(
+                src_dict, args.encoder_embed_dim, args.encoder_embed_path
+            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
-            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
-            decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
+            encoder_embed_tokens = build_embedding(
+                src_dict, args.encoder_embed_dim, args.encoder_embed_path
+            )
+            decoder_embed_tokens = build_embedding(
+                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
+            )

        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
@@ -141,6 +164,15 @@ class TransformerEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = \
+                encoder_out_dict['encoder_out'].index_select(1, new_order)
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -222,12 +254,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        return x, attn

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()
@@ -385,16 +411,18 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, le
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(m.weight[padding_idx], 0)
    else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
+        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
    return m


@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
+    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
+    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
...
@@ -24,7 +24,7 @@ class MultiheadAttention(nn.Module):
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self._mask = None
...
@@ -56,12 +56,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
        # recompute/expand embeddings if needed
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
-        if max_pos > self.weights.size(0):
+        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
-            ).type_as(self.weights)
+            )
+        self.weights = self.weights.type_as(self._float_tensor)

        if incremental_state is not None:
@@ -69,7 +69,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
-        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
+        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
...
@@ -117,6 +117,7 @@ def get_parser(desc, default_task='translation'):
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
+    parser.add_argument('--fp16', action='store_true', help='use FP16')

    # Task definitions can be found under fairseq/tasks/
    parser.add_argument(
@@ -187,8 +188,6 @@ def add_optimization_args(parser):
                       ' (default is to normalize by number of tokens)')
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')
-    group.add_argument('--fp16', action='store_true',
-                       help='use FP16 during training')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
@@ -210,6 +209,8 @@ def add_optimization_args(parser):
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')
+    group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
+                       help='minimum loss scale (for FP16 training)')
    return group
...
@@ -268,7 +268,7 @@ class SequenceGenerator(object):
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
-                    encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
+                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)
...
@@ -256,7 +256,7 @@ def parse_embedding(embed_path):
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
-            pieces = line.strip().split()
+            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
    return embed_dict
...
@@ -43,6 +43,8 @@ def main(args):
    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        if args.fp16:
+            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
...
@@ -82,9 +82,9 @@ def main(args):
    # Optimize ensemble for generation
    for model in models:
-        model.make_generation_fast_(
-            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
-        )
+        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        if args.fp16:
+            model.half()

    # Initialize generator
    translator = SequenceGenerator(
...