Commit fa5222c2 authored by thomwolf's avatar thomwolf

update readme

parent ab90d4cd
# PyTorch Pretrained Bert - PyTorch Pretrained OpenAI GPT
[![CircleCI](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT.svg?style=svg)](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT)
This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) and of [OpenAI's TensorFlow repository for the OpenAI GPT model](https://github.com/openai/finetune-transformer-lm).

BERT was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. This PyTorch implementation of BERT is provided with [Google's pre-trained models](https://github.com/google-research/bert), examples, notebooks and a command-line interface to load any pre-trained TensorFlow checkpoint for BERT.

OpenAI GPT was released together with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. This PyTorch implementation of OpenAI GPT is provided with [OpenAI's pre-trained model](https://github.com/openai/finetune-transformer-lm) and a command-line interface that was used to convert the pre-trained NumPy checkpoint into the provided PyTorch model.
## Content
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top computing the start and end logits of the answer span (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Three PyTorch models (`torch.nn.Module`) for OpenAI GPT with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
- [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
- [`OpenAIGPTLMHeadModel`](./pytorch_pretrained_bert/modeling_openai.py#L691) - OpenAI GPT Transformer with the tied language modeling head on top (**fully pre-trained**),
- [`OpenAIGPTDoubleHeadsModel`](./pytorch_pretrained_bert/modeling_openai.py#L752) - OpenAI GPT Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**).
- Three tokenizers for BERT (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
- `WordpieceTokenizer` - WordPiece tokenization,
- `BertTokenizer` - performs end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
- One tokenizer for OpenAI GPT (in the [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) file):
- `OpenAIGPTTokenizer` - performs Byte-Pair-Encoding (BPE) tokenization.
- One optimizer for BERT (in the [`optimization.py`](./pytorch_pretrained_bert/optimization.py) file):
- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
- One optimizer for OpenAI GPT (in the [`optimization_openai.py`](./pytorch_pretrained_bert/optimization_openai.py) file):
- `OpenAIGPTAdam` - OpenAI GPT version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
- A configuration class for BERT (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
- A configuration class for OpenAI GPT (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
- `OpenAIGPTConfig` - Configuration class to store the configuration of an `OpenAIGPTModel` with utilities to read and write from JSON configuration files (a short configuration sketch follows this list).
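As an illustration, here is a minimal sketch of the configuration round-trip with `BertConfig`; the constructor values shown are the BERT-base defaults, and the JSON helpers are assumed to be present on both configuration classes:

```python
from pytorch_pretrained_bert.modeling import BertConfig

# Build a configuration from scratch (values shown are the BERT-base defaults)
config = BertConfig(vocab_size_or_config_json_file=30522,
                    hidden_size=768,
                    num_hidden_layers=12,
                    num_attention_heads=12,
                    intermediate_size=3072)

# Round-trip through JSON
json_string = config.to_json_string()
# config = BertConfig.from_json_file('bert_config.json')  # or load from a file
```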
The repository further comprises:
- Five examples on how to use Bert (in the [`examples` folder](./examples)):
These notebooks are detailed in the [Notebooks](#notebooks) section of this readme.
- A command-line interface to convert any TensorFlow checkpoint (for BERT) or NumPy checkpoint (for OpenAI GPT) into a PyTorch dump:
This CLI is detailed in the [Command-line interface](#Command-line-interface) section of this readme.
## Usage
### BERT
Here is a quick-start example using the `BertTokenizer`, `BertModel` and `BertForMaskedLM` classes with Google AI's pre-trained `Bert base uncased` model. See the [doc section](#doc) below for all the details on these classes.
First let's prepare a tokenized input with `BertTokenizer`
```python
# ... (the full BERT quick-start code is collapsed in this diff) ...
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```
### OpenAI GPT
Here is a quick-start example using the `OpenAIGPTTokenizer`, `OpenAIGPTModel` and `OpenAIGPTLMHeadModel` classes with OpenAI's pre-trained model. See the [doc section](#doc) below for all the details on these classes.
First let's prepare a tokenized input with `OpenAIGPTTokenizer`
```python
import torch
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel
# Load pre-trained model tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
```
Let's see how to use `OpenAIGPTModel` to get hidden states
```python
# Load pre-trained model (weights)
model = OpenAIGPTModel.from_pretrained('openai-gpt')
model.eval()
# Compute the hidden states
hidden_states = model(tokens_tensor)
```
And how to use `OpenAIGPTLMHeadModel`
```python
# Load pre-trained model (weights)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()
# Predict all tokens
predictions = model(tokens_tensor)
# Get the predicted next token
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
```
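Going one step further, here is a minimal greedy-decoding sketch built on the same classes. It is not an official API of the package: the vocabulary size `40478` and the restriction of the argmax to word tokens are assumptions of this sketch (the LM logits span the whole embedding matrix, including special and position slots):

```python
# Greedy generation sketch: repeatedly feed the growing sequence back in
# (simple but inefficient); restrict the argmax to the word-token range [0, n_vocab)
n_vocab = 40478  # assumed size of the OpenAI GPT BPE vocabulary
generated = list(indexed_tokens)
for _ in range(10):
    input_tensor = torch.tensor([generated])
    with torch.no_grad():
        logits = model(input_tensor)
    next_index = torch.argmax(logits[0, -1, :n_vocab]).item()
    generated.append(next_index)
print(tokenizer.convert_ids_to_tokens(generated))
```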
## Doc
Here is a detailed documentation of the classes in the package and how to use them:
| Sub-section | Description |
|-|-|
| [Loading Google AI's/OpenAI's pre-trained weights](#Loading-Google-AI-or-OpenAI-pre-trained-weights-and-PyTorch-dump) | How to load Google AI's/OpenAI's pre-trained weights or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the eleven PyTorch model classes: the eight BERT classes (`BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice`, `BertForQuestionAnswering`) and the three OpenAI GPT classes (`OpenAIGPTModel`, `OpenAIGPTLMHeadModel`, `OpenAIGPTDoubleHeadsModel`) |
| [Tokenizers: `BertTokenizer` and `OpenAIGPTTokenizer`](#Tokenizers) | API of the `BertTokenizer` and `OpenAIGPTTokenizer` classes |
| [Optimizers: `BertAdam` and `OpenAIGPTAdam`](#Optimizers) | API of the `BertAdam` and `OpenAIGPTAdam` classes |
### Loading Google AI or OpenAI pre-trained weights and PyTorch dump
To load one of Google AI's or OpenAI's pre-trained models, or a PyTorch saved model (an instance of `BertForPreTraining` or `OpenAIGPTModel` saved with `torch.save()`), the PyTorch model classes and the tokenizers can be instantiated as
```python
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
```
where
- `BERT_CLASS` is either a tokenizer to load the vocabulary (`BertTokenizer` or `OpenAIGPTTokenizer` classes) or one of the eight BERT or three OpenAI GPT PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice`, `BertForQuestionAnswering`, `OpenAIGPTModel`, `OpenAIGPTLMHeadModel` or `OpenAIGPTDoubleHeadsModel`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
- the shortcut name of a Google AI or OpenAI pre-trained model selected in the list:
- `bert-base-uncased`: 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-large-uncased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
- `bert-base-multilingual-uncased`: (Orig, not recommended) 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
- `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters
- a path or url to a pretrained model archive containing:
- `bert_config.json` or `openai_gpt_config.json` a configuration file for the model, and
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining` or `OpenAIGPTModel` (saved with the usual `torch.save()`)
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future downloads (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).
**When using an `uncased model`, make sure to pass `--do_lower_case` to the example training scripts (or pass `do_lower_case=True` to `FullTokenizer` if you're using your own script and loading the tokenizer yourself).**
Examples:
```python
# BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# OpenAI GPT
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')
```
### PyTorch models
The token-level classifier takes as input the full sequence of the last hidden states.
An example on how to use this class is given in the [`run_squad.py`](./examples/run_squad.py) script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
#### 9. `OpenAIGPTModel`
`OpenAIGPTModel` is the basic OpenAI GPT Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.
The main implementation difference between BERT and OpenAI GPT is that OpenAI GPT uses a single embedding matrix to store the word, special (`[SEP]`, `[CLS]`...) and position embeddings.
The embeddings are ordered as follows in the word embeddings matrix:

    [0,                                          ----------------------
     ...                                          -> word embeddings
     config.vocab_size - 1,                      ______________________
     config.vocab_size,
     ...                                          -> special embeddings
     config.vocab_size + config.n_special - 1,   ______________________
     config.vocab_size + config.n_special,
     ...                                          -> position embeddings
     total_num_embeddings - 1]                   ______________________
where `total_num_embeddings` can be obtained as `config.total_num_embeddings` and is:

    total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
You should use the associated indices to index the embeddings.
The special token embeddings (`[SEP]`, `[CLS]`...) are not pre-trained and need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
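As a short hedged sketch of this mechanism (the index arithmetic follows the embedding layout above; the three tokens mirror the `_start_`/`_delimiter_`/`_classify_` tokens of the ROCStories fine-tuning script):

```python
from pytorch_pretrained_bert import OpenAIGPTModel

# Reserve three trainable special-token slots
model = OpenAIGPTModel.from_pretrained('openai-gpt')
model.set_num_special_tokens(3)

# The new slots sit right after the word embeddings in the shared matrix,
# i.e. at indices config.vocab_size ... config.vocab_size + 2; they are
# freshly initialized and must be learned during fine-tuning.
```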
The inputs and outputs are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
[`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py)
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
- `position_ids`: an optional torch.LongTensor with the same shape as input_ids with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[).
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids. You can use it to add a third embedding (the previous two being the word and position embeddings) to each token in the sentence.
This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
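To make the index convention concrete, here is a small sketch of a forward pass with explicit `position_ids`; the `model.config` attribute and the keyword names are assumptions based on the input list above:

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTModel

model = OpenAIGPTModel.from_pretrained('openai-gpt')
model.eval()
config = model.config  # assumption: the loaded model keeps its OpenAIGPTConfig here

input_ids = torch.tensor([[0, 1, 2, 3]])  # placeholder BPE token ids
batch_size, seq_len = input_ids.shape

# Position indices live after the word and special embeddings in the shared matrix
offset = config.vocab_size + config.n_special
position_ids = torch.arange(offset, offset + seq_len, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_len)

hidden_states = model(input_ids, position_ids=position_ids)
```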
#### 10. `OpenAIGPTLMHeadModel`
`OpenAIGPTLMHeadModel` includes the `OpenAIGPTModel` Transformer followed by a language modeling head with weights tied to the input embeddings (no additional parameters).
*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#9-openaigptmodel) class plus optional labels:
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
*Outputs*:
- if `lm_labels` is not `None`:
Outputs the language modeling loss.
- else:
Outputs `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_num_embeddings] (or more generally [d_1, ..., d_n, total_num_embeddings] where d_1 ... d_n are the dimensions of input_ids)
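For instance, a minimal sketch of a training step using `lm_labels` (whether the head shifts the labels internally should be checked in the doc strings of [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py); the placeholder ids and the reuse of inputs as targets are assumptions of this sketch):

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.train()

input_ids = torch.tensor([[0, 1, 2, 3]])  # placeholder BPE token ids
lm_labels = input_ids.clone()             # reuse inputs as targets; set entries to -1 to mask them

loss = model(input_ids, lm_labels=lm_labels)  # a scalar loss when lm_labels is given
loss.backward()
```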
#### 11. `OpenAIGPTDoubleHeadsModel`
`OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters) and:
- a multiple choice classifier (linear layer).
*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#9-openaigptmodel) class plus a classification mask and two optional labels:
- `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with a value of 1 at the position whose last hidden state is used for the classification (usually the `[CLS]` token) and 0 otherwise.
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
*Outputs*:
- if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
- else:
  Outputs a tuple with:
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_num_embeddings]
- `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
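To make the expected shapes concrete, here is a hedged sketch for a batch of one example with two choices; the placeholder ids would be real BPE indices in practice, and the import path is assumed to mirror the other model classes:

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel

model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
model.eval()

batch_size, num_choices, seq_len = 1, 2, 7
input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)  # placeholder ids

# 1 at the position whose last hidden state feeds the classifier
# (usually the final classification token of each choice), 0 elsewhere
multiple_choice_token_mask = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
multiple_choice_token_mask[:, :, -1] = 1

lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask=multiple_choice_token_mask)
```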
### Tokenizers:
#### `BertTokenizer`
`BertTokenizer` performs end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
#### `OpenAIGPTTokenizer`
`OpenAIGPTTokenizer` performs Byte-Pair-Encoding (BPE) tokenization.
This class has two arguments:
- `vocab_file`: path to a vocabulary file.
- `merges_file`: path to a file containing the BPE merges.
and three methods:
- `tokenize(text)`: convert a `str` into a list of `str` tokens by performing BPE tokenization.
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens into a list of `int` indices in the vocabulary.
- `convert_ids_to_tokens(ids)`: convert a list of `int` indices into a list of `str` tokens in the vocabulary.
Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) for the details of the `OpenAIGPTTokenizer`.
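A short round-trip sketch with the three methods listed above:

```python
from pytorch_pretrained_bert import OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
tokens = tokenizer.tokenize("Jim Henson was a puppeteer")  # list of str BPE tokens
ids = tokenizer.convert_tokens_to_ids(tokens)              # list of int vocabulary indices
assert tokenizer.convert_ids_to_tokens(ids) == tokens      # round-trip
```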
### Optimizers:
#### `BertAdam`
`BertAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with the PyTorch Adam optimizer are the following:
The optimizer accepts the following arguments:
- `weight_decay`: Weight decay. Default: `0.01`
- `max_grad_norm`: Maximum norm for the gradients (`-1` means no clipping). Default: `1.0`
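For reference, a typical construction mirroring the example scripts; `model` is a loaded BERT model as in the Usage section, and `num_train_steps` is an assumed variable holding the total number of optimization steps:

```python
from pytorch_pretrained_bert.optimization import BertAdam

num_train_steps = 1000  # assumption: total number of optimizer updates
optimizer = BertAdam(model.parameters(),
                     lr=5e-5,
                     warmup=0.1,              # fraction of t_total spent warming up
                     t_total=num_train_steps,
                     weight_decay=0.01,
                     max_grad_norm=1.0)
```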
#### `OpenAIGPTAdam`
`OpenAIGPTAdam` is similar to `BertAdam`.
The difference with `BertAdam` is that `OpenAIGPTAdam` compensates for bias as in the regular Adam optimizer (`BertAdam` omits this bias correction, matching the original TensorFlow implementation of BERT).
`OpenAIGPTAdam` accepts the same arguments as `BertAdam`.
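A sketch with the same arguments (the hyper-parameters shown are the defaults of the ROCStories fine-tuning script below); note that this script imports the class as `OpenAIAdam` from `optimization_openai.py`, so check the exact class name in your version:

```python
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam

optimizer = OpenAIAdam(model.parameters(),
                       lr=6.25e-5,
                       warmup=0.002,
                       t_total=num_train_steps)
```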
## Examples
| Sub-section | Description |
|-|-|
Please follow the instructions given in the notebooks to run and modify them.
## Command-line interface
A command-line interface is provided to convert a TensorFlow checkpoint into a PyTorch dump of the `BertForPreTraining` class (for BERT) or a NumPy checkpoint into a PyTorch dump of the `OpenAIGPTModel` class (for OpenAI GPT).
### BERT
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) into a PyTorch save file by using the [`./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
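Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model. The full command is collapsed in this diff; the sketch below follows the CLI's positional arguments (TensorFlow checkpoint, configuration file, output path), with `$BERT_BASE_DIR` assumed to point at the unpacked Google checkpoint:

```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
  $BERT_BASE_DIR/bert_model.ckpt \
  $BERT_BASE_DIR/bert_config.json \
  $BERT_BASE_DIR/pytorch_model.bin
```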
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
### OpenAI GPT
Here is an example of the conversion process for a pre-trained OpenAI GPT model, assuming that your NumPy checkpoint is saved in the same format as the OpenAI pre-trained model (see [here](https://github.com/openai/finetune-transformer-lm))
```shell
export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights
pytorch_pretrained_bert convert_openai_checkpoint \
$OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
$PYTORCH_DUMP_OUTPUT \
[OPENAI_GPT_CONFIG]
```
## TPU
TPU support and pretraining scripts
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
" Run OpenAI GPT on RocStories"
import argparse
import os
import random
import logging
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
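# NOTE: the commented-out imports below come from the original OpenAI
# finetune-transformer-lm codebase; several helpers used further down
# (iter_data, ResultLogger, make_path, encode_dataset, rocstories,
# rocstories_analysis, TextEncoder, DoubleHeadModel,
# load_openai_pretrained_model, MultipleChoiceLossCompute) are still
# defined there and have not been ported to this repository yet.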
# from analysis import rocstories as rocstories_analysis
# from datasets import rocstories
# from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
# from opt import OpenAIAdam
# from text_utils import TextEncoder
# from utils import (encode_dataset, iter_data,
# ResultLogger, make_path)
# from loss import MultipleChoiceLossCompute
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTDoubleHeadsModel
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def transform_roc(X1, X2, X3):
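    # Builds the two (context + ending) candidate sequences for each ROCStories example:
    #   xmb[i, choice, position, 0] holds the BPE token ids,
    #   xmb[i, choice, position, 1] holds the matching position-embedding indices,
    # and mmb is a float mask with 1 on real tokens and 0 on padding.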
n_batch = len(X1)
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)
start = encoder['_start_']
delimiter = encoder['_delimiter_']
for i, (x1, x2, x3) in enumerate(zip(X1, X2, X3)):
x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
l12 = len(x12)
l13 = len(x13)
xmb[i, 0, :l12, 0] = x12
xmb[i, 1, :l13, 0] = x13
mmb[i, 0, :l12] = 1
mmb[i, 1, :l13] = 1
# Position information that is added to the input embeddings in the TransformerModel
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
return xmb, mmb
def iter_apply(Xs, Ms, Ys):
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
logits = []
cost = 0
with torch.no_grad():
dh_model.eval()
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
_, clf_logits = dh_model(XMB)
clf_logits *= n
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
clf_losses *= n
logits.append(clf_logits.to("cpu").numpy())
cost += clf_losses.sum().item()
logits = np.concatenate(logits, 0)
return logits, cost
def iter_predict(Xs, Ms):
logits = []
with torch.no_grad():
dh_model.eval()
for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
_, clf_logits = dh_model(XMB)
logits.append(clf_logits.to("cpu").numpy())
logits = np.concatenate(logits, 0)
return logits
def log(save_dir, desc):
global best_score
print("Logging")
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
tr_cost = tr_cost / len(trY[:n_valid])
va_cost = va_cost / n_valid
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
if submit:
score = va_acc
if score > best_score:
best_score = score
path = os.path.join(save_dir, desc, 'best_params')
torch.save(dh_model.state_dict(), make_path(path))
def predict(dataset, submission_dir):
filename = filenames[dataset]
pred_fn = pred_fns[dataset]
label_decoder = label_decoders[dataset]
predictions = pred_fn(iter_predict(teX, teM))
if label_decoder is not None:
predictions = [label_decoder[prediction] for prediction in predictions]
path = os.path.join(submission_dir, filename)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as f:
f.write('{}\t{}\n'.format('index', 'prediction'))
for i, prediction in enumerate(predictions):
f.write('{}\t{}\n'.format(i, prediction))
def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates
dh_model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
lm_logits, clf_logits = dh_model(XMB)
compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
n_updates += 1
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
log(save_dir, desc)
argmax = lambda x: np.argmax(x, 1)
pred_fns = {
'rocstories': argmax,
}
filenames = {
'rocstories': 'ROCStories.tsv',
}
label_decoders = {
'rocstories': None,
}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--desc', type=str, help="Description")
parser.add_argument('--dataset', type=str)
parser.add_argument('--log_dir', type=str, default='log/')
parser.add_argument('--save_dir', type=str, default='save/')
parser.add_argument('--data_dir', type=str, default='data/')
parser.add_argument('--submission_dir', type=str, default='submission/')
parser.add_argument('--submit', action='store_true')
parser.add_argument('--analysis', action='store_true')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--n_iter', type=int, default=3)
parser.add_argument('--n_batch', type=int, default=8)
parser.add_argument('--max_grad_norm', type=int, default=1)
parser.add_argument('--lr', type=float, default=6.25e-5)
parser.add_argument('--lr_warmup', type=float, default=0.002)
parser.add_argument('--n_ctx', type=int, default=512)
parser.add_argument('--n_embd', type=int, default=768)
parser.add_argument('--n_head', type=int, default=12)
parser.add_argument('--n_layer', type=int, default=12)
parser.add_argument('--embd_pdrop', type=float, default=0.1)
parser.add_argument('--attn_pdrop', type=float, default=0.1)
parser.add_argument('--resid_pdrop', type=float, default=0.1)
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
parser.add_argument('--n_transfer', type=int, default=12)
parser.add_argument('--lm_coef', type=float, default=0.5)
parser.add_argument('--b1', type=float, default=0.9)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--e', type=float, default=1e-8)
parser.add_argument('--n_valid', type=int, default=374)
args = parser.parse_args()
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Constants
submit = args.submit
dataset = args.dataset
n_ctx = args.n_ctx
save_dir = args.save_dir
desc = args.desc
data_dir = args.data_dir
log_dir = args.log_dir
submission_dir = args.submission_dir
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device", device, "n_gpu", n_gpu)
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)
print("Encoding dataset...")
((trX1, trX2, trX3, trY),
(vaX1, vaX2, vaX3, vaY),
(teX1, teX2, teX3)) = encode_dataset(*rocstories(data_dir, n_valid=args.n_valid),
encoder=text_encoder)
encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
clf_token = encoder['_classify_']
n_special = 3
max_len = n_ctx // 2 - 2
n_ctx = min(max(
[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3, n_ctx)
vocab = n_vocab + n_special + n_ctx
trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
if submit:
teX, teM = transform_roc(teX1, teX2, teX3)
n_train = len(trY)
n_valid = len(vaY)
n_batch_train = args.n_batch * max(n_gpu, 1)
n_updates_total = (n_train // n_batch_train) * args.n_iter
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)
criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(dh_model.parameters(),
lr=args.lr,
schedule=args.lr_schedule,
warmup=args.lr_warmup,
t_total=n_updates_total,
b1=args.b1,
b2=args.b2,
e=args.e,
l2=args.l2,
vector_l2=args.vector_l2,
max_grad_norm=args.max_grad_norm)
compute_loss_fct = MultipleChoiceLossCompute(criterion,
criterion,
args.lm_coef,
model_opt)
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
dh_model.to(device)
dh_model = nn.DataParallel(dh_model)
n_updates = 0
n_epochs = 0
if dataset != 'stsb':
trYt = trY
if submit:
path = os.path.join(save_dir, desc, 'best_params')
torch.save(dh_model.state_dict(), make_path(path))
best_score = 0
for i in range(args.n_iter):
print("running epoch", i)
run_epoch()
n_epochs += 1
log(save_dir, desc)
if submit:
path = os.path.join(save_dir, desc, 'best_params')
dh_model.load_state_dict(torch.load(path))
predict(dataset, args.submission_dir)
if args.analysis:
rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
os.path.join(log_dir, 'rocstories.jsonl'))
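# Usage sketch (hypothetical file name; the commit does not show the path):
#   python run_openai_gpt.py --dataset rocstories --desc rocstories --data_dir data/ --submit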