Commit c2794070 authored by Sergey Edunov's avatar Sergey Edunov Committed by Myle Ott
Browse files

Update IWSLT configuration for transformer

parent dbe96371
...@@ -36,6 +36,31 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \ ...@@ -36,6 +36,31 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
```
To train transformer model on IWSLT'14 German to English:
```
# Preparation steps are the same as for fconv model.
# Train the model (better for a single GPU setup):
$ mkdir -p checkpoints/transformer
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
-a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
--label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
--min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --max-update 50000 \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--adam-betas '(0.9, 0.98)' --save-dir checkpoints/transformer
# Average 10 latest checkpoints:
$ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
--num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
# Generate:
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path checkpoints/transformer/model.pt \
--batch-size 128 --beam 5 --remove-bpe
```
### prepare-wmt14en2de.sh
......
...@@ -588,14 +588,14 @@ def base_architecture(args): ...@@ -588,14 +588,14 @@ def base_architecture(args):
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
    """Transformer architecture preset for the IWSLT'14 German-English task.

    Fills in defaults for the encoder/decoder hyperparameters (512-dim
    embeddings, 1024-dim feed-forward layers, 4 attention heads, 6 layers
    on each side), then delegates every remaining setting to
    ``base_architecture``.

    Each assignment uses ``getattr`` with a default so that any value the
    user supplied explicitly on the command line takes precedence over the
    preset.

    Args:
        args: parsed argument namespace; mutated in place.
    """
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    base_architecture(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment