Commit 354944e6 authored by VictorSanh

[distillation] big update w/ new weights

parent 0d1dad6d
@@ -2,12 +2,21 @@
This folder contains the original code used to train DistilBERT as well as examples showcasing how to use DistilBERT.
**2019, September 19th - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 97% of `BERT-base`'s performance on GLUE, and 86.9 F1 score on the SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
## What is DistilBERT

DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster, while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique that compresses a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option for putting large-scale pretrained Transformer models into production.
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5). *Please note that we will publish a formal write-up with updated and more complete results in the near future (September 19th).*
Here are the updated results on the dev sets of GLUE:
| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI |
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:|
| BERT-base | **77.6** | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7 | 91.2 | 43.7 |
| DistilBERT | **75.2** | 49.1 | 81.8 | 90.2 | 87.0 | 89.2 | 62.9 | 92.7 | 90.7 | 44.4 |
## Setup

@@ -20,7 +29,7 @@ This part of the library has only been tested with Python3.6+. There are few speci
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):

- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (the concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, a hidden dimension of 768 and 12 heads, totaling 66M parameters.
- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` fine-tuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the `bert-base-uncased` version of BERT reaches an F1 score of 88.5).
Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to keep naming consistent across the library's models.
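As a minimal sketch (assuming the `DistilBert*` classes shipped with `pytorch-transformers` v1.2.0; the example sentence is arbitrary), extracting contextual embeddings looks like this:

```python
import torch
from pytorch_transformers import DistilBertTokenizer, DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])  # (1, seq_length)
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # (1, seq_length, 768)
```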
...
@@ -92,11 +92,11 @@ class Dataset:
Sequences that are too short are simply removed. This could be tuned.
""" """
init_size = len(self) init_size = len(self)
indices = self.lengths > 5 indices = self.lengths > 11
self.token_ids = self.token_ids[indices] self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices] self.lengths = self.lengths[indices]
new_size = len(self) new_size = len(self)
logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.') logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.')
def print_statistics(self):
"""
...
@@ -18,15 +18,18 @@
import os
import math
import psutil
import time
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from pytorch_transformers import WarmupLinearSchedule
from utils import logger
from dataset import Dataset
@@ -58,10 +61,12 @@ class Distiller:
self.alpha_ce = params.alpha_ce
self.alpha_mlm = params.alpha_mlm
self.alpha_mse = params.alpha_mse
self.alpha_cos = params.alpha_cos
assert self.alpha_ce >= 0.
assert self.alpha_mlm >= 0.
assert self.alpha_mse >= 0.
assert self.alpha_cos >= 0.
assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.
self.mlm_mask_prop = params.mlm_mask_prop
assert 0.0 <= self.mlm_mask_prop <= 1.0
@@ -81,17 +86,21 @@ class Distiller:
self.last_loss = 0
self.last_loss_ce = 0
self.last_loss_mlm = 0
if self.alpha_mse > 0.: self.last_loss_mse = 0
if self.alpha_cos > 0.: self.last_loss_cos = 0
self.last_log = 0
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
if self.alpha_mse > 0.:
self.mse_loss_fct = nn.MSELoss(reduction='sum')
if self.alpha_cos > 0.:
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')
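# Note (sketch of the overall objective): the total distillation loss combines these terms as a
# weighted sum, loss = alpha_ce * loss_ce + alpha_mlm * loss_mlm + alpha_mse * loss_mse + alpha_cos * loss_cos,
# where each optional term is only computed and logged when its weight is > 0.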
logger.info('--- Initializing model optimizer')
assert params.gradient_accumulation_steps >= 1
self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
@@ -104,6 +113,9 @@ class Distiller:
lr=params.learning_rate,
eps=params.adam_epsilon,
betas=(0.9, 0.98))
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
logger.info(f'--- Scheduler: {params.scheduler_type}')
self.scheduler = WarmupLinearSchedule(self.optimizer,
warmup_steps=warmup_steps,
t_total=num_train_optimization_steps)
@@ -272,11 +284,14 @@ class Distiller:
The real training loop.
"""
if self.is_master: logger.info('Starting training')
self.last_log = time.time()
self.student.train()
self.teacher.eval()
for _ in range(self.params.n_epoch):
if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
if self.multi_gpu:
torch.distributed.barrier()
iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
for __ in range(self.num_steps_epoch):
@@ -314,9 +329,9 @@ class Distiller:
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
"""
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
assert s_logits.size() == t_logits.size()
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
@@ -341,6 +356,22 @@ class Distiller:
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
loss += self.alpha_mse * loss_mse
if self.alpha_cos > 0.:
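# Hidden-state alignment term: push the student's last-layer hidden states towards the teacher's
# on non-padded positions. CosineEmbeddingLoss with a target of +1 maximizes the cosine
# similarity between each selected student/teacher vector pair.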
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
assert s_hidden_states.size() == t_hidden_states.size()
dim = s_hidden_states.size(-1)
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
loss += self.alpha_cos * loss_cos
self.total_loss_epoch += loss.item()
self.last_loss = loss.item()
self.last_loss_ce = loss_ce.item()
@@ -348,6 +379,8 @@ class Distiller:
self.last_loss_mlm = loss_mlm.item()
if self.alpha_mse > 0.:
self.last_loss_mse = loss_mse.item()
if self.alpha_cos > 0.:
self.last_loss_cos = loss_cos.item()
self.optimize(loss)
@@ -396,6 +429,7 @@ class Distiller:
if self.n_total_iter % self.params.log_interval == 0:
self.log_tensorboard()
self.last_log = time.time()
if self.n_total_iter % self.params.checkpoint_interval == 0:
self.save_checkpoint()
@@ -421,9 +455,12 @@ class Distiller:
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
if self.alpha_mse > 0.:
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
if self.alpha_cos > 0.:
self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)
def end_epoch(self):
"""
...
@@ -2,3 +2,5 @@ gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3
scipy==1.3.1
pytorch_transformers==1.2.0
@@ -20,7 +20,7 @@ import pickle
import random
import time
import numpy as np
from pytorch_transformers import BertTokenizer, RobertaTokenizer
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,16 +32,21 @@ def main():
parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times (tokenization + token_to_ids).")
parser.add_argument('--file_path', type=str, default='data/dump.txt',
help='The path to the data.')
parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta'])
parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
help="The tokenizer to use.")
parser.add_argument('--dump_file', type=str, default='data/dump',
help='The dump file prefix.')
args = parser.parse_args()
logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
if args.tokenizer_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
elif args.tokenizer_type == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `<s>` for roberta
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `</s>` for roberta
logger.info(f'Loading text from {args.file_path}')
with open(args.file_path, 'r', encoding='utf8') as fp:
@@ -56,8 +61,8 @@ def main():
interval = 10000
start = time.time()
for text in data:
text = f'{bos} {text.strip()} {sep}'
token_ids = tokenizer.encode(text)
rslt.append(token_ids)
iter += 1
@@ -69,7 +74,7 @@ def main():
logger.info(f'{len(data)} examples processed.')
dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
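# Token ids are stored as uint16: both the BERT (~30K) and RoBERTa (~50K) vocabularies
# fit in 16 bits, which keeps the pickled dataset compact.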
rslt_ = [np.uint16(d) for d in rslt]
random.shuffle(rslt_)
logger.info(f'Dump to {dp_file}')
...
@@ -15,59 +15,73 @@
"""
Preprocessing script before training DistilBERT.
"""
from pytorch_transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation") parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
parser.add_argument("--bert_model", default='bert-base-uncased', type=str) parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str) parser.add_argument("--model_name", default='bert-base-uncased', type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action='store_true')
args = parser.parse_args() args = parser.parse_args()
if args.model_type == 'bert':
model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = 'bert'
elif args.model_type == 'roberta':
model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = 'roberta'
state_dict = model.state_dict()
compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
state_dict[f'{prefix}.embeddings.{w}.weight']
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
std_idx = 0
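# The 6 student layers are initialized from teacher layers [0, 2, 4, 7, 9, 11],
# i.e. roughly one teacher layer out of two.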
for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
std_idx += 1
if args.model_type == 'bert':
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
elif args.model_type == 'roberta':
compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
print(f'N layers selected for distillation: {std_idx}')
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
...
@@ -23,7 +23,7 @@ import shutil
import numpy as np
import torch
from pytorch_transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
@@ -70,8 +70,10 @@ def main():
help="Load student initialization checkpoint.")
parser.add_argument("--from_pretrained_config", default=None, type=str,
help="Load student initialization architecture config.")
parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
help="Teacher type (BERT, RoBERTa).")
parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
help="The teacher model.")
parser.add_argument("--temperature", default=2., type=float,
help="Temperature applied to the distillation softmax.")
@@ -81,6 +83,8 @@ def main():
help="Linear weight for the MLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--alpha_cos", default=0.0, type=float,
help="Linear weight of the cosine embedding loss. Must be >=0.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float, parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.") help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float, parser.add_argument("--word_mask", default=0.8, type=float,
@@ -165,11 +169,14 @@ def main():
### TOKENIZER ###
if args.teacher_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
elif args.teacher_type == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}')
args.special_tok_ids = special_tok_ids
@@ -202,11 +209,12 @@ def main():
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
stu_architecture_config.output_hidden_states = True
student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
student = DistilBertForMaskedLM(stu_architecture_config)
@@ -216,10 +224,13 @@ def main():
## TEACHER ##
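# output_hidden_states=True exposes the teacher's hidden states so that the cosine
# alignment loss (--alpha_cos) can compare them with the student's.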
if args.teacher_type == 'bert':
teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
elif args.teacher_type == 'roberta':
teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.teacher_name}.')
## DISTILLER ##
torch.cuda.empty_cache()
...