Commit b4d68ff9 authored by kimiyoung

Merge branch 'master' of github.com:kimiyoung/transformer-xl

parents 564c62c6 19a5852f
@@ -7,18 +7,18 @@ This repository contains the code in both **PyTorch** and **TensorFlow** for our
>Preprint 2018
## TensorFlow
- The source code is in the `tf/` folder, supporting (1) single-node multi-gpu training, and (2) multi-host TPU training.
- Besides the source code, we also provide pretrained "TensorFlow" models with state-of-the-art (SoTA) performances reported in the paper.
- Please refer to `tf/README.md` for details.
## PyTorch
- The source code is in the `pytorch/` folder, supporting single-node multi-gpu training via the module `nn.DataParallel`.
- Please refer to `pytorch/README.md` for details.
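Since `nn.DataParallel` spreads each batch across all GPUs visible to the process, one common way to control how many devices a run uses is the standard `CUDA_VISIBLE_DEVICES` variable. The snippet below is only an illustrative sketch; the script name and working directory come from the PyTorch instructions further down and are assumed to be run from the `pytorch/` folder.

```bash
# Restrict a DataParallel run to two GPUs (device IDs are machine-specific).
CUDA_VISIBLE_DEVICES=0,1 bash run_enwik8_base.sh train --work_dir ./exp/enwik8_2gpu
```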
## Results
Transformer-XL achieves new state-of-the-art results on multiple language modeling benchmarks. Transformer-XL is also the first to break through the 1.0 barrier on char-level language modeling. Below is a summary.
@@ -31,4 +31,4 @@ Transformer-XL | **0.99** | **1.08** | **21.8** | **18.3** | **54.5**
## Acknowledgement
A large portion of the `getdata.sh` script comes from the [awd-lstm](https://github.com/salesforce/awd-lstm-lm/) repo. Happy Language Modeling :)
## Introduction
This directory contains our pytorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our pytorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large.sh` are for the SoTA setting with large models, which might not be directly runnable on a local GPU machine.
@@ -7,16 +7,16 @@ This directory contains our pytorch implementation of Transformer-XL. Note that
The pytorch implementation produces similar results to the TF codebase under the same settings in our preliminary experiments.
## Prerequisite
- PyTorch 0.4: `conda install pytorch torchvision -c pytorch`
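As an optional sanity check (not part of the original instructions), the installed version can be printed to confirm the environment matches the 0.4.x series these scripts target:

```bash
# Print the installed PyTorch version; the scripts in this directory target 0.4.x.
python -c "import torch; print(torch.__version__)"
```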
## Data Preparation
`bash getdata.sh`
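The script only needs to be run once. Assuming it downloads and preprocesses the corpora into a `data/` directory at the repository root (an assumption worth verifying against the script itself), a quick check could look like:

```bash
# One-time download and preprocessing of the benchmark corpora.
bash getdata.sh

# Assumed output location; inspect getdata.sh if the directory differs.
ls data/
```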
## Training and Evaluation
#### Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL
@@ -24,11 +24,11 @@ The pytorch implementation produces similar results to the TF codebase under the
- Training
`bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`
- Evaluation
`bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`
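Put together, a minimal end-to-end run might look like the sketch below; the working directory is just a placeholder path, and the base configuration still assumes the multi-GPU setup baked into the run script.

```bash
# Placeholder working directory; any writable path works.
WORK_DIR=./exp/enwik8_base

# Train the 12-layer base model, then evaluate from the same working directory.
bash run_enwik8_base.sh train --work_dir $WORK_DIR
bash run_enwik8_base.sh eval --work_dir $WORK_DIR
```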
@@ -38,11 +38,11 @@ The pytorch implementation produces similar results to the TF codebase under the
- Training
`bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`
- Evaluation
`bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`
@@ -53,10 +53,10 @@ The pytorch implementation produces similar results to the TF codebase under the
- `--fp16` and `--dynamic-loss-scale`: Run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
- Note: to use the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
- To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts.
- To see performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0` (see the sketch after this list).
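As a hedged illustration (the exact set of flags each run script forwards to the trainer may differ, so treat this as a sketch rather than the definitive interface), the fp16 flags are appended to the usual training command, while the `mem_len` / `attn_type` ablations are made by editing those values in the run scripts as noted above:

```bash
# Pseudo-fp16 training with dynamic loss scaling (requires NVIDIA apex to be installed).
bash run_wt103_base.sh train --work_dir ./exp/wt103_fp16 --fp16 --dynamic-loss-scale

# For the ablations, set mem_len=0 (no recurrence) and/or attn_type=2 (vanilla attention)
# in the chosen run script, then launch it as usual.
```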
#### Other datasets:
- `Text8` character-level language modeling: check out `run_text8_base.sh`
- `lm1b` word-level language modeling: check out `run_lm1b_base.sh`
@@ -80,8 +80,8 @@ We used 32, 32, 64, and 512 TPU cores for training our best models on enwik8, te
For `dataset` in `[enwik8, lm1b, wt103, text8]` (an example invocation follows the list):
- check out `scripts/dataset_base_gpu.sh` for GPU training and evaluation
- check out `scripts/dataset_large_tpu.sh` for TPU training and evaluation
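To make the naming pattern concrete, substituting `enwik8` for `dataset` gives something along these lines; the sub-commands shown are assumptions carried over from the PyTorch runners, and `tf/README.md` remains the authoritative reference for what each script actually accepts.

```bash
# Assumed invocation pattern; consult tf/README.md for the exact sub-commands.
bash scripts/enwik8_base_gpu.sh train   # single-node multi-GPU training
bash scripts/enwik8_base_gpu.sh eval    # evaluation of the trained checkpoint
```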