Commit 481d9c4f authored by thomwolf

Merge branch 'master' into tf2

parents 4ddc31ff 7c0f2d0a
@@ -97,20 +97,20 @@ Fine-tuning the library models for sequence classification on the GLUE benchmark
 Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.

 GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an
 uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train
 batch size of 24. Some of these tasks have a small dataset and training can lead to high variance in the results
 between different runs. We report the median on 5 runs (with different seeds) for each of the metrics.

 | Task  | Metric                       | Result      |
 |-------|------------------------------|-------------|
-| CoLA  | Matthew's corr               | 55.75       |
-| SST-2 | Accuracy                     | 92.09       |
-| MRPC  | F1/Accuracy                  | 90.48/86.27 |
-| STS-B | Pearson/Spearman corr.       | 89.03/88.64 |
-| QQP   | Accuracy/F1                  | 90.92/87.72 |
-| MNLI  | Matched acc./Mismatched acc. | 83.74/84.06 |
-| QNLI  | Accuracy                     | 91.07       |
-| RTE   | Accuracy                     | 68.59       |
+| CoLA  | Matthew's corr               | 48.87       |
+| SST-2 | Accuracy                     | 91.74       |
+| MRPC  | F1/Accuracy                  | 90.70/86.27 |
+| STS-B | Pearson/Spearman corr.       | 91.39/91.04 |
+| QQP   | Accuracy/F1                  | 90.79/87.66 |
+| MNLI  | Matched acc./Mismatched acc. | 83.70/84.83 |
+| QNLI  | Accuracy                     | 89.31       |
+| RTE   | Accuracy                     | 71.43       |
 | WNLI  | Accuracy                     | 43.66       |

 Some of these results are significantly different from the ones reported on the test set
...
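For readers reproducing these numbers, the metric names in the table map onto standard implementations; a hedged sketch of how such metrics are typically computed with scikit-learn and scipy (this may differ in detail from the repository's own metric code):

```python
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

preds, labels = [1, 0, 1, 1], [1, 0, 0, 1]          # toy classification outputs
scores, targets = [0.1, 4.2, 2.5], [0.0, 4.0, 3.0]  # toy STS-B style regression outputs

print(matthews_corrcoef(labels, preds))   # CoLA: Matthews correlation
print(accuracy_score(labels, preds))      # SST-2, QNLI, RTE, WNLI, MNLI: accuracy
print(f1_score(labels, preds))            # MRPC, QQP: F1 (reported alongside accuracy)
print(pearsonr(scores, targets)[0])       # STS-B: Pearson correlation
print(spearmanr(scores, targets)[0])      # STS-B: Spearman correlation
```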
@@ -2,12 +2,21 @@
 This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.

+**2019, September 19th - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 97% of `BERT-base`'s performance on GLUE, and 86.9 F1 score on SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
+
 ## What is DistilBERT

-DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving over 95% of Bert's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
+DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model (the teacher) into a smaller model (the student). By distilling BERT, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale trained Transformer models into production.

 For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
-).
+). *Please note that we will publish a formal write-up with updated and more complete results in the near future (September 19th).*
+
+Here are the updated results on the dev sets of GLUE:
+
+| Model      | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP  | RTE  | SST-2 | STS-B | WNLI |
+| :---:      | :---:       | :---:| :---:| :---:| :---:| :---:| :---:| :---: | :---: | :---:|
+| BERT-base  | **77.6**    | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7  | 91.2  | 43.7 |
+| DistilBERT | **75.2**    | 49.1 | 81.8 | 90.2 | 87.0 | 89.2 | 62.9 | 92.7  | 90.7  | 44.4 |
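Knowledge distillation trains the student on the teacher's soft target distribution, usually with a temperature-scaled softmax. A minimal sketch of that soft-target loss term (the exact weighting and the additional terms used to train DistilBERT live in `distiller.py` further down this diff; names and shapes here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 2.0
kl_loss = nn.KLDivLoss(reduction='batchmean')

s_logits = torch.randn(8, 30522)  # student logits for 8 (masked) positions
t_logits = torch.randn(8, 30522)  # teacher logits for the same positions

# Soften both distributions with the temperature; the T**2 factor keeps gradient
# magnitudes comparable across temperatures (Hinton et al., 2015).
loss_ce = kl_loss(F.log_softmax(s_logits / temperature, dim=-1),
                  F.softmax(t_logits / temperature, dim=-1)) * temperature ** 2
print(loss_ce.item())
```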
 ## Setup
@@ -20,7 +29,7 @@ This part of the library has only be tested with Python3.6+. There are few speci
 Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):

 - `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden dimension of 768 and 12 heads, totalling 66M parameters.
-- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knwoledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.2 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
+- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` fine-tuned with (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the `bert-base-uncased` version of BERT reaches an 88.5 F1 score).

 Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we provide a link to this tokenizer under the `DistilBertTokenizer` name for consistent naming across the library models.
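As a quick reference, loading the checkpoints above follows the usual library pattern; a minimal sketch (assuming a transformers/pytorch-transformers install that ships the DistilBERT classes):

```python
import torch
from transformers import DistilBertTokenizer, DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
outputs = model(input_ids)
last_hidden_states = outputs[0]   # (batch_size, sequence_length, 768)
print(last_hidden_states.shape)
```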
...
@@ -92,11 +92,11 @@ class Dataset:
         Too short sequences are simply removed. This could be tuned.
         """
         init_size = len(self)
-        indices = self.lengths > 5
+        indices = self.lengths > 11
         self.token_ids = self.token_ids[indices]
         self.lengths = self.lengths[indices]
         new_size = len(self)
-        logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.')
+        logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.')

     def print_statistics(self):
         """
...
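The threshold change above (5 to 11) tightens the same boolean-mask filter; a toy numpy illustration of the filtering pattern (names and values are made up):

```python
import numpy as np

lengths = np.array([4, 12, 30, 9, 128])
token_ids = np.array([None] * 5, dtype=object)   # stand-ins for 5 tokenized sequences
token_ids[:] = [[0] * n for n in lengths]

indices = lengths > 11            # boolean mask: [False, True, True, False, True]
token_ids = token_ids[indices]    # keeps the 3 sequences longer than 11 tokens
lengths = lengths[indices]
print(f'Removed {5 - len(lengths)} too short (<=11 tokens) sequences.')
```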
@@ -18,15 +18,18 @@
 import os
 import math
 import psutil
+import time
 from tensorboardX import SummaryWriter
 from tqdm import trange, tqdm
 import numpy as np
+import psutil
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.optim import AdamW

-from transformers import AdamW, WarmupLinearSchedule
+from transformers import WarmupLinearSchedule

 from utils import logger
 from dataset import Dataset
@@ -58,10 +61,12 @@ class Distiller:
         self.alpha_ce = params.alpha_ce
         self.alpha_mlm = params.alpha_mlm
         self.alpha_mse = params.alpha_mse
+        self.alpha_cos = params.alpha_cos
         assert self.alpha_ce >= 0.
         assert self.alpha_mlm >= 0.
         assert self.alpha_mse >= 0.
-        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.
+        assert self.alpha_cos >= 0.
+        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

         self.mlm_mask_prop = params.mlm_mask_prop
         assert 0.0 <= self.mlm_mask_prop <= 1.0
@@ -81,17 +86,21 @@ class Distiller:
         self.last_loss = 0
         self.last_loss_ce = 0
         self.last_loss_mlm = 0
-        self.last_loss_mse = 0
+        if self.alpha_mse > 0.: self.last_loss_mse = 0
+        if self.alpha_cos > 0.: self.last_loss_cos = 0
+        self.last_log = 0

         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
-        self.mse_loss_fct = nn.MSELoss(reduction='sum')
+        if self.alpha_mse > 0.:
+            self.mse_loss_fct = nn.MSELoss(reduction='sum')
+        if self.alpha_cos > 0.:
+            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

         logger.info('--- Initializing model optimizer')
         assert params.gradient_accumulation_steps >= 1
         self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
         num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
-        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

         no_decay = ['bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
@@ -104,9 +113,11 @@ class Distiller:
                                        lr=params.learning_rate,
                                        eps=params.adam_epsilon,
                                        betas=(0.9, 0.98))
+
+        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
         self.scheduler = WarmupLinearSchedule(self.optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=num_train_optimization_steps)

         if self.fp16:
             try:
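For context, the optimizer and scheduler wiring this hunk touches follows the usual BERT-style recipe: no weight decay on biases and LayerNorm weights, AdamW, and a linear warmup schedule. A hedged sketch under those assumptions (variable names and hyperparameter values are illustrative; `WarmupLinearSchedule` is the pre-2.1 transformers API, later replaced by `get_linear_schedule_with_warmup`):

```python
import math
import torch.nn as nn
from torch.optim import AdamW
from transformers import WarmupLinearSchedule  # linear warmup, then linear decay

model = nn.Linear(10, 10)                      # stand-in for the student model
num_train_optimization_steps = 1000
warmup_prop = 0.05

# No weight decay for biases and LayerNorm weights (standard BERT-style grouping).
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-4, eps=1e-6, betas=(0.9, 0.98))

warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps)
```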
@@ -272,11 +283,14 @@ class Distiller:
         The real training loop.
         """
         if self.is_master: logger.info('Starting training')
+        self.last_log = time.time()
         self.student.train()
         self.teacher.eval()

         for _ in range(self.params.n_epoch):
             if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
+            if self.multi_gpu:
+                torch.distributed.barrier()

             iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
             for __ in range(self.num_steps_epoch):
@@ -314,9 +328,9 @@ class Distiller:
             attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
             mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
         """
-        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]                    # (bs, seq_length, voc_size)
+        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
         with torch.no_grad():
-            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]                # (bs, seq_length, voc_size)
+            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()

         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
@@ -340,6 +354,22 @@ class Distiller:
         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse
+        if self.alpha_cos > 0.:
+            s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
+            t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
+            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)     # (bs, seq_length, dim)
+            assert s_hidden_states.size() == t_hidden_states.size()
+            dim = s_hidden_states.size(-1)
+
+            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
+            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)          # (bs * seq_length, dim)
+            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
+            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)          # (bs * seq_length, dim)
+
+            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
+            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
+            loss += self.alpha_cos * loss_cos

         self.total_loss_epoch += loss.item()
         self.last_loss = loss.item()
@@ -348,6 +378,8 @@ class Distiller:
             self.last_loss_mlm = loss_mlm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
+        if self.alpha_cos > 0.:
+            self.last_loss_cos = loss_cos.item()

         self.optimize(loss)
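The new `alpha_cos` term aligns the student's and teacher's last hidden states with a cosine objective. A minimal standalone sketch of what `nn.CosineEmbeddingLoss` does in this setting (shapes and names are illustrative, not the repository's code):

```python
import torch
import torch.nn as nn

cos_loss = nn.CosineEmbeddingLoss(reduction='mean')

bs, seq_len, dim = 2, 4, 8
s_hidden = torch.randn(bs * seq_len, dim)   # student vectors, one per non-padded token
t_hidden = torch.randn(bs * seq_len, dim)   # teacher vectors for the same tokens

# target = +1 asks the loss to maximize cosine similarity for each pair,
# i.e. push every student hidden state towards the direction of the teacher's.
target = s_hidden.new_ones(bs * seq_len)
loss_cos = cos_loss(s_hidden, t_hidden, target)   # scalar: 1 - mean cosine similarity
print(loss_cos.item())
```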
@@ -396,6 +428,7 @@ class Distiller:
         if self.n_total_iter % self.params.log_interval == 0:
             self.log_tensorboard()
+            self.last_log = time.time()

         if self.n_total_iter % self.params.checkpoint_interval == 0:
             self.save_checkpoint()
@@ -421,9 +454,12 @@ class Distiller:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
+        if self.alpha_cos > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)

         self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
         self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
+        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)

     def end_epoch(self):
         """
...
@@ -2,3 +2,5 @@ gitpython==3.0.2
 tensorboard>=1.14.0
 tensorboardX==1.8
 psutil==5.6.3
+scipy==1.3.1
+pytorch_transformers==1.2.0
@@ -20,7 +20,7 @@ import pickle
 import random
 import time
 import numpy as np
-from transformers import BertTokenizer
+from transformers import BertTokenizer, RobertaTokenizer

 import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,16 +32,21 @@ def main():
     parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
     parser.add_argument('--file_path', type=str, default='data/dump.txt',
                         help='The path to the data.')
-    parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
+    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta'])
+    parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
                         help="The tokenizer to use.")
     parser.add_argument('--dump_file', type=str, default='data/dump',
                         help='The dump file prefix.')
     args = parser.parse_args()

-    logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
-    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
+    logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
+    if args.tokenizer_type == 'bert':
+        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.tokenizer_type == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
+    bos = tokenizer.special_tokens_map['bos_token']  # `[CLS]` for bert, `<s>` for roberta
+    sep = tokenizer.special_tokens_map['sep_token']  # `[SEP]` for bert, `</s>` for roberta

     logger.info(f'Loading text from {args.file_path}')
     with open(args.file_path, 'r', encoding='utf8') as fp:
@@ -56,8 +61,8 @@ def main():
     interval = 10000
     start = time.time()
     for text in data:
-        text = f'[CLS] {text.strip()} [SEP]'
-        token_ids = bert_tokenizer.encode(text)
+        text = f'{bos} {text.strip()} {sep}'
+        token_ids = tokenizer.encode(text)
         rslt.append(token_ids)

         iter += 1
@@ -69,7 +74,7 @@ def main():
     logger.info(f'{len(data)} examples processed.')

-    dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
+    dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
     rslt_ = [np.uint16(d) for d in rslt]
     random.shuffle(rslt_)
     logger.info(f'Dump to {dp_file}')
...
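One thing to double-check when running this preprocessing with a BERT tokenizer: `special_tokens_map` only contains the special tokens a tokenizer actually defines, and BERT-style tokenizers may not define a `bos_token`, in which case the `cls_token` key is the safer lookup and yields the same strings the comments above expect. A quick sketch for inspecting this:

```python
from transformers import BertTokenizer, RobertaTokenizer

bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')

# Inspect which special tokens each tokenizer defines.
print(bert_tok.special_tokens_map)     # expected to include 'cls_token': '[CLS]', 'sep_token': '[SEP]'
print(roberta_tok.special_tokens_map)  # expected to include 'bos_token': '<s>', 'sep_token': '</s>'

# 'cls_token' resolves to the leading special token for both models:
print(bert_tok.special_tokens_map['cls_token'])     # '[CLS]'
print(roberta_tok.special_tokens_map['cls_token'])  # '<s>'
```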
@@ -15,59 +15,73 @@
 """
 Preprocessing script before training DistilBERT.
 """
-from transformers import BertForPreTraining
+from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
 import argparse

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation")
-    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
-    parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str)
+    parser = argparse.ArgumentParser(description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation")
+    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_name", default='bert-base-uncased', type=str)
+    parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
     args = parser.parse_args()

-    model = BertForPreTraining.from_pretrained(args.bert_model)
+    if args.model_type == 'bert':
+        model = BertForMaskedLM.from_pretrained(args.model_name)
+        prefix = 'bert'
+    elif args.model_type == 'roberta':
+        model = RobertaForMaskedLM.from_pretrained(args.model_name)
+        prefix = 'roberta'

     state_dict = model.state_dict()
     compressed_sd = {}

     for w in ['word_embeddings', 'position_embeddings']:
         compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
-            state_dict[f'bert.embeddings.{w}.weight']
+            state_dict[f'{prefix}.embeddings.{w}.weight']
     for w in ['weight', 'bias']:
         compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
-            state_dict[f'bert.embeddings.LayerNorm.{w}']
+            state_dict[f'{prefix}.embeddings.LayerNorm.{w}']

     std_idx = 0
     for teacher_idx in [0, 2, 4, 7, 9, 11]:
         for w in ['weight', 'bias']:
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
             compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
-                state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
+                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1

-    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-    if args.vocab_transform:
-        for w in ['weight', 'bias']:
-            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
+    if args.model_type == 'bert':
+        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+        if args.vocab_transform:
+            for w in ['weight', 'bias']:
+                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
+    elif args.model_type == 'roberta':
+        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
+        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
+        if args.vocab_transform:
+            for w in ['weight', 'bias']:
+                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
+                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']

     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
...
@@ -23,7 +23,7 @@ import shutil
 import numpy as np
 import torch

-from transformers import BertTokenizer, BertForMaskedLM
+from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
 from transformers import DistilBertForMaskedLM, DistilBertConfig

 from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
                         help="Load student initialization checkpoint.")
     parser.add_argument("--from_pretrained_config", default=None, type=str,
                         help="Load student initialization architecture config.")
-    parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
-                        help="The teacher BERT model.")
+    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
+                        help="Teacher type (BERT, RoBERTa).")
+    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
+                        help="The teacher model.")
     parser.add_argument("--temperature", default=2., type=float,
                         help="Temperature for the distillation softmax.")
@@ -81,6 +83,8 @@ def main():
                         help="Linear weight for the MLM loss. Must be >=0.")
     parser.add_argument("--alpha_mse", default=0.0, type=float,
                         help="Linear weight of the MSE loss. Must be >=0.")
+    parser.add_argument("--alpha_cos", default=0.0, type=float,
+                        help="Linear weight of the cosine embedding loss. Must be >=0.")
     parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                         help="Proportion of tokens for which we need to make a prediction.")
     parser.add_argument("--word_mask", default=0.8, type=float,
@@ -165,11 +169,14 @@ def main():
     ### TOKENIZER ###
-    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
+    if args.teacher_type == 'bert':
+        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
+    elif args.teacher_type == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
     special_tok_ids = {}
-    for tok_name, tok_symbol in bert_tokenizer.special_tokens_map.items():
-        idx = bert_tokenizer.all_special_tokens.index(tok_symbol)
-        special_tok_ids[tok_name] = bert_tokenizer.all_special_ids[idx]
+    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
+        idx = tokenizer.all_special_tokens.index(tok_symbol)
+        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
     logger.info(f'Special tokens {special_tok_ids}')
     args.special_tok_ids = special_tok_ids
@@ -197,16 +204,17 @@ def main():
     ## STUDENT ##
     if args.from_pretrained_weights is not None:
-        assert os.path.isfile(os.path.join(args.from_pretrained_weights))
-        assert os.path.isfile(os.path.join(args.from_pretrained_config))
+        assert os.path.isfile(args.from_pretrained_weights)
+        assert os.path.isfile(args.from_pretrained_config)
         logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
         logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
         stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
+        stu_architecture_config.output_hidden_states = True
         student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
                                                         config=stu_architecture_config)
     else:
         args.vocab_size_or_config_json_file = args.vocab_size
-        stu_architecture_config = DistilBertConfig(**vars(args))
+        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
         student = DistilBertForMaskedLM(stu_architecture_config)
@@ -216,10 +224,13 @@ def main():
     ## TEACHER ##
-    teacher = BertForMaskedLM.from_pretrained(args.bert_model)
+    if args.teacher_type == 'bert':
+        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
+    elif args.teacher_type == 'roberta':
+        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f'cuda:{args.local_rank}')
-    logger.info(f'Teacher loaded from {args.bert_model}.')
+    logger.info(f'Teacher loaded from {args.teacher_name}.')

     ## DISTILLER ##
     torch.cuda.empty_cache()
...
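The `output_hidden_states=True` settings above are what make the hidden-state tuples available to the cosine loss in `distiller.py` (the `s_logits, s_hidden_states = self.student(...)` unpacking). A hedged sketch of the output layout for transformers 1.x/2.x-style models (the tiny config values here are illustrative, not the real student configuration):

```python
import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

# Tiny illustrative config with hidden-state outputs enabled.
config = DistilBertConfig(vocab_size_or_config_json_file=30522, n_layers=2, n_heads=2,
                          dim=32, hidden_dim=64, output_hidden_states=True)
model = DistilBertForMaskedLM(config)

input_ids = torch.randint(0, 30522, (1, 8))
attention_mask = torch.ones_like(input_ids)
logits, hidden_states = model(input_ids=input_ids, attention_mask=attention_mask)

print(logits.shape)             # (1, 8, 30522): MLM scores
print(len(hidden_states))       # n_layers + 1: embedding output plus one per layer
print(hidden_states[-1].shape)  # (1, 8, 32): last layer, the tensor used by the cosine loss
```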
@@ -36,7 +36,6 @@ class OpenAIGPTConfig(PretrainedConfig):
     Args:
         vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
-        n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
         n_positions: Number of positional embeddings.
         n_ctx: Size of the causal mask (usually same as n_positions).
         n_embd: Dimensionality of the embeddings and hidden states.
...
@@ -183,8 +183,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
     def add_special_tokens_single_sequence(self, token_ids):
         """
-        Adds special tokens to a sequence pair for sequence classification tasks.
-        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
+        Adds special tokens to a sequence for sequence classification tasks.
+        An XLNet sequence has the following format: X [SEP][CLS]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
@@ -192,8 +192,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
     def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
-        Adds special tokens to a sequence for sequence classification tasks.
-        An XLNet sequence has the following format: X [SEP][CLS]
+        Adds special tokens to a sequence pair for sequence classification tasks.
+        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
         """
         sep = [self.sep_token_id]
...
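For readers unfamiliar with the XLNet convention documented in the corrected docstrings above (special tokens appended at the end, unlike BERT's leading `[CLS]`), a hedged illustration of how such sequences are assembled; this is a sketch with made-up token ids, not the library's implementation:

```python
# Illustrative token-id assembly for the XLNet formats described above.
sep = [4]   # hypothetical id of XLNet's separator token
cls = [3]   # hypothetical id of XLNet's classification token
token_ids_0 = [17, 21, 35]
token_ids_1 = [9, 12]

single_sequence = token_ids_0 + sep + cls                    # X [SEP] [CLS]
sequence_pair = token_ids_0 + sep + token_ids_1 + sep + cls  # A [SEP] B [SEP] [CLS]
print(single_sequence, sequence_pair)
```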