Commit 912a377e authored by thomwolf

dilbert -> distilbert

parent c9bce181
...@@ -13,7 +13,7 @@ The library currently contains PyTorch implementations, pre-trained model weight
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/pytorch-transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the blogpost [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
) by Victor Sanh, Lysandre Debut and Thomas Wolf.
These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on performance in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).
...
# DistilBERT
This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.
## What is DistilBERT
DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving over 95% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale trained Transformer models into production.
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
).
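As a rough illustration of the distillation objective, the student is trained to match the teacher's softened output distribution in addition to the usual masked language modeling objective. The snippet below is a minimal sketch only (the function name and temperature value are illustrative), not the exact loss implemented in `train.py`:
```python
import torch.nn as nn
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions with a temperature, then push the student
    # towards the teacher with a KL divergence. The T**2 factor keeps gradient
    # magnitudes comparable across temperatures (Hinton et al., 2015).
    kl = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1))
    return kl * temperature ** 2
```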
## How to use DistilBERT
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden size of 768 and 12 heads, for a total of 66M parameters.
- `distilbert-base-uncased-distilled-squad`: a version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.2 on the dev set (for comparison, BERT's `bert-base-uncased` version reaches an 88.5 F1 score). A question-answering sketch is shown after the code example below.
Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we provide a link to this tokenizer under the `DistilBertTokenizer` name for naming consistency across the library's models.
```python
import torch
from pytorch_transformers import DistilBertTokenizer, DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
```
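The SQuAD-distilled checkpoint can be used through the question-answering head in the same way. The following is a minimal sketch: the question/context string and the greedy span decoding are purely illustrative, not the full SQuAD evaluation pipeline.
```python
import torch
from pytorch_transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')

# DistilBERT has no token_type_ids: separate the question and the context with [SEP].
text = "[CLS] Who wrote the blog post ? [SEP] The blog post was written by Victor Sanh . [SEP]"
input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch size 1

start_logits, end_logits = model(input_ids)[:2]
start = start_logits.argmax(dim=-1).item()
end = end_logits.argmax(dim=-1).item()
answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[0, start:end + 1].tolist())
print(answer_tokens)  # greedy answer span, as wordpiece tokens
```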
## How to train DistilBERT
In the following, we will explain how you can train your own compressed model.
...@@ -68,7 +68,7 @@ python train.py \
By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available on the command line; please look in `train.py` or run `python train.py --help` to list them.
We highly encourage you to use distributed training for training DistilBERT, as the training corpus is quite large. Here's an example that runs a distributed training on a single node with 4 GPUs:
```bash
export NODE_RANK=0
...
```
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dataloaders to train DistilBERT.
"""
from typing import List
import math
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The distiller to distil DistilBERT.
"""
import os
import math
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
import argparse
import pickle
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
from pytorch_transformers import BertForPreTraining
import torch
...@@ -33,32 +33,32 @@ if __name__ == '__main__':
compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
state_dict[f'bert.embeddings.{w}.weight']
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'bert.embeddings.LayerNorm.{w}']
std_idx = 0
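# Each of the 6 student layers is initialised from one selected teacher layer (roughly every other layer of the 12-layer teacher).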
for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
std_idx += 1
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
"""
from collections import Counter
import argparse
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training DistilBERT.
"""
import os
import argparse
...@@ -24,7 +24,7 @@ import numpy as np
import torch
from pytorch_transformers import BertTokenizer, BertForMaskedLM
from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
...@@ -201,13 +201,13 @@ def main():
assert os.path.isfile(os.path.join(args.from_pretrained_config))
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DistilBertConfig(**vars(args))
student = DistilBertForMaskedLM(stu_architecture_config)
if args.n_gpu > 0:
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils to train DistilBERT.
"""
import git
import json
...
...@@ -7,7 +7,7 @@ from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_utils import (PreTrainedTokenizer)
...@@ -41,9 +41,9 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertConfig, DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
...
...@@ -30,7 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_distilbert import DistilBertConfig, DistilBertModel
from .modeling_utils import PreTrainedModel, SequenceSummary
...@@ -111,8 +111,8 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
...@@ -228,8 +228,8 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
...
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch DistilBERT model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
...@@ -36,19 +36,19 @@ import logging
logger = logging.getLogger(__name__)
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin"
}
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json"
}
class DistilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
...@@ -66,7 +66,7 @@ class DilBertConfig(PretrainedConfig):
qa_dropout=0.1,
seq_classif_dropout=0.2,
**kwargs):
super(DistilBertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -398,17 +398,17 @@ class Transformer(nn.Module):
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = DistilBertConfig
pretrained_model_archive_map = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "distilbert"
def __init__(self, *inputs, **kwargs):
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
...@@ -425,36 +425,36 @@ class DilBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
DISTILBERT_START_DOCSTRING = r"""
DistilBERT is a small, fast, cheap and light Transformer model
trained by distilling Bert base. It has 40% less parameters than
`bert-base-uncased`, runs 60% faster while preserving over 95% of
Bert's performances as measured on the GLUE language understanding benchmark.
Here are the differences between the interface of Bert and DistilBert:
- DistilBert doesn't have `token_type_ids`, so you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
- DistilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let us know if you need this option.
For more information on DistilBERT, please refer to our
`detailed blog post`_
.. _`detailed blog post`:
https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
Parameters:
config (:class:`~pytorch_transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
DISTILBERT_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The input sequences should start with `[CLS]` and `[SEP]` tokens.
For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DistilBERT.
**attention_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
...@@ -465,9 +465,9 @@ DILBERT_INPUTS_DOCSTRING = r"""
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertModel(DistilBertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
...@@ -482,15 +482,15 @@ class DilBertModel(DilBertPreTrainedModel):
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(DistilBertModel, self).__init__(config)
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
...@@ -543,9 +543,9 @@ class DilBertModel(DilBertPreTrainedModel):
return output # last-layer hidden-state, (all hidden_states), (all attentions)
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """, @add_start_docstrings("""DistilBert Model with a `masked language modeling` head on top. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DilBertForMaskedLM(DilBertPreTrainedModel): class DistilBertForMaskedLM(DistilBertPreTrainedModel):
r""" r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss. Labels for computing the masked language modeling loss.
...@@ -568,19 +568,19 @@ class DilBertForMaskedLM(DilBertPreTrainedModel): ...@@ -568,19 +568,19 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
Examples:: Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased') tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DilBertForMaskedLM.from_pretrained('dilbert-base-uncased') model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids) outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2] loss, prediction_scores = outputs[:2]
""" """
def __init__(self, config): def __init__(self, config):
super(DilBertForMaskedLM, self).__init__(config) super(DistilBertForMaskedLM, self).__init__(config)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.dilbert = DilBertModel(config) self.distilbert = DistilBertModel(config)
self.vocab_transform = nn.Linear(config.dim, config.dim) self.vocab_transform = nn.Linear(config.dim, config.dim)
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
self.vocab_projector = nn.Linear(config.dim, config.vocab_size) self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
...@@ -595,14 +595,14 @@ class DilBertForMaskedLM(DilBertPreTrainedModel): ...@@ -595,14 +595,14 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
Export to TorchScript can't handle parameter sharing so we are cloning them instead. Export to TorchScript can't handle parameter sharing so we are cloning them instead.
""" """
self._tie_or_clone_weights(self.vocab_projector, self._tie_or_clone_weights(self.vocab_projector,
self.dilbert.embeddings.word_embeddings) self.distilbert.embeddings.word_embeddings)
def forward(self, def forward(self,
input_ids: torch.tensor, input_ids: torch.tensor,
attention_mask: torch.tensor = None, attention_mask: torch.tensor = None,
masked_lm_labels: torch.tensor = None, masked_lm_labels: torch.tensor = None,
head_mask: torch.tensor = None): head_mask: torch.tensor = None):
dlbrt_output = self.dilbert(input_ids=input_ids, dlbrt_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask) head_mask=head_mask)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim) hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
...@@ -620,10 +620,10 @@ class DilBertForMaskedLM(DilBertPreTrainedModel): ...@@ -620,10 +620,10 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions) return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
@add_start_docstrings("""DilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """, the pooled output) e.g. for GLUE tasks. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING) DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DilBertForSequenceClassification(DilBertPreTrainedModel): class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
r""" r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss. Labels for computing the sequence classification/regression loss.
...@@ -646,8 +646,8 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel): ...@@ -646,8 +646,8 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
Examples:: Examples::
tokenizer = DilBertTokenizer.from_pretrained('dilbert-base-uncased') tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DilBertForSequenceClassification.from_pretrained('dilbert-base-uncased') model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels) outputs = model(input_ids, labels=labels)
...@@ -655,10 +655,10 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel): ...@@ -655,10 +655,10 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
""" """
def __init__(self, config): def __init__(self, config):
super(DilBertForSequenceClassification, self).__init__(config) super(DistilBertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.dilbert = DilBertModel(config) self.distilbert = DistilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim) self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, config.num_labels) self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout) self.dropout = nn.Dropout(config.seq_classif_dropout)
...@@ -670,17 +670,17 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel): ...@@ -670,17 +670,17 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
attention_mask: torch.tensor = None, attention_mask: torch.tensor = None,
labels: torch.tensor = None, labels: torch.tensor = None,
head_mask: torch.tensor = None): head_mask: torch.tensor = None):
dilbert_output = self.dilbert(input_ids=input_ids, distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask) head_mask=head_mask)
hidden_state = dilbert_output[0] # (bs, seq_len, dim) hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim) pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim) pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, num_labels)
outputs = (logits,) + distilbert_output[1:]
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
...@@ -693,10 +693,10 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""DilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
...@@ -724,8 +724,8 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
...@@ -734,9 +734,9 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
"""
def __init__(self, config):
super(DistilBertForQuestionAnswering, self).__init__(config)
self.distilbert = DistilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
...@@ -749,10 +749,10 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
start_positions: torch.tensor = None,
end_positions: torch.tensor = None,
head_mask: torch.tensor = None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
...@@ -760,7 +760,7 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
outputs = (start_logits, end_logits,) + distilbert_output[1:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
...
...@@ -20,23 +20,23 @@ import unittest
import shutil
import pytest
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class DistilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
DistilBertForSequenceClassification)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DistilBertModelTester(object):
def __init__(self,
parent,
...@@ -100,7 +100,7 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = DistilBertConfig(
vocab_size_or_config_json_file=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
...@@ -119,8 +119,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
list(result["loss"].size()),
[])
def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertModel(config=config)
model.eval()
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
...@@ -132,8 +132,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
...@@ -145,8 +145,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels)
result = {
...@@ -162,9 +162,9 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = DistilBertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, input_mask, sequence_labels)
result = {
...@@ -183,33 +183,33 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
return config, inputs_dict
def setUp(self):
self.model_tester = DistilBertModelTest.DistilBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_distilbert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
# @pytest.mark.slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/pytorch_transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
...
...@@ -18,20 +18,20 @@ import os
import unittest
from io import open
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest
class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
def get_tokenizer(self):
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
...
...@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for DistilBERT."""
from __future__ import absolute_import, division, print_function, unicode_literals
...@@ -31,21 +31,21 @@ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'distilbert-base-uncased': 512,
'distilbert-base-uncased-distilled-squad': 512,
}
class DistilBertTokenizer(BertTokenizer):
r"""
Constructs a DistilBertTokenizer.
:class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
...