"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "92b3cb786d3b32b2ef830154b4131cbacb91699e"
Commit dddd6b99 authored by VictorSanh

Update DistilBERT training code

parent f9453d15
...@@ -9,6 +9,12 @@ DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and l
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5).
## Setup
This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0). Note that there is a small internal bug in the current version of PyTorch available on pip that causes a memory leak in our training/distillation. It has recently been fixed and will likely be integrated into the next release. In the meantime, we recommend [compiling PyTorch from source](https://github.com/pytorch/pytorch#from-source). Please refer to [issue 1179](https://github.com/huggingface/pytorch-transformers/issues/1179) for more details.
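As a quick sanity check before launching a distillation, you can verify the installed PyTorch version at runtime. This is a minimal sketch (not part of the repository scripts), assuming only that the scripts target v1.2.0 as stated above:

```python
import torch

# The distillation scripts target PyTorch v1.2.0 (breaking changes compared to v1.1.0),
# so fail early if an older version is installed.
major, minor = (int(x) for x in torch.__version__.split(".")[:2])
assert (major, minor) >= (1, 2), f"PyTorch >= 1.2.0 is required, found {torch.__version__}"
```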
## How to use DistilBERT
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT; a short usage sketch follows this excerpt):
...
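For reference, here is a minimal usage sketch for the pre-trained DistilBERT checkpoints mentioned above. It assumes the `DistilBertTokenizer`/`DistilBertModel` classes and the `distilbert-base-uncased` shortcut name; check the main README for the exact identifiers.

```python
import torch
from pytorch_transformers import DistilBertModel, DistilBertTokenizer

# Load one of the pre-trained English checkpoints (shortcut name assumed).
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # shape: (batch_size, sequence_length, hidden_size)
```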
...@@ -17,6 +17,7 @@
""" """
import os import os
import math import math
import psutil
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import trange, tqdm from tqdm import trange, tqdm
import numpy as np import numpy as np
...@@ -192,7 +193,7 @@ class Distiller:
x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device)  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
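For context on the `torch.bool` change above, here is a small self-contained sketch of the masking step this hunk implements: target positions are sampled without replacement, weighted by per-token probabilities, and collected into a boolean mask (the names mirror the diff; shapes and probabilities are illustrative):

```python
import math
import torch

bs, max_seq_len, vocab_size = 2, 8, 100
token_ids = torch.randint(0, vocab_size, (bs, max_seq_len))
lengths = torch.tensor([8, 5])                 # true (unpadded) sequence lengths
token_probs = torch.rand(vocab_size)           # per-token masking weights (illustrative)
mlm_mask_prop = 0.15                           # proportion of tokens selected for MLM

x_prob = token_probs[token_ids.flatten()]
n_tgt = math.ceil(mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)

# PyTorch 1.2.0 expects boolean masks to use dtype=torch.bool instead of torch.uint8.
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool)
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)    # True at the positions selected for prediction
```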
...@@ -216,7 +217,7 @@ class Distiller:
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
return token_ids, attn_mask, mlm_labels
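The second compatibility fix in this hunk follows from the mask now being `torch.bool`: PyTorch 1.2.0 does not support subtraction on bool tensors, so the logical negation `~pred_mask` replaces `1 - pred_mask`. A minimal illustration:

```python
import torch

mlm_labels = torch.tensor([[5, 7, 9], [2, 4, 6]])
pred_mask = torch.tensor([[True, False, True], [False, True, False]])

# `1 - pred_mask` raises an error on bool tensors in PyTorch 1.2.0;
# `~pred_mask` selects the positions that were NOT chosen for prediction.
mlm_labels[~pred_mask] = -1   # -1 marks positions to be ignored by the MLM loss
print(mlm_labels)             # tensor([[ 5, -1,  9], [-1,  4, -1]])
```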
...@@ -379,9 +380,9 @@ class Distiller:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()  # previously called before `optimizer.step()`, cf pytorch 1.2.0 compatibility
def iter(self):
"""
...@@ -418,6 +419,8 @@ class Distiller:
if self.alpha_mse > 0.:
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
def end_epoch(self):
"""
...
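The new `global/memory_usage` scalar added in the hunk above logs host RAM usage (converted from bytes to megabytes via `psutil`) next to the existing loss and learning-rate curves, which makes it easier to spot issues like the memory leak mentioned in the README note. A standalone sketch of the same logging call (the `log_dir` and step counter are illustrative):

```python
import psutil
from tensorboardX import SummaryWriter

tensorboard = SummaryWriter(log_dir="runs/distillation")   # log_dir is illustrative
n_total_iter = 0                                           # global optimization step counter

# Host memory in use, converted from bytes to megabytes, as in the diff above.
memory_mb = psutil.virtual_memory()._asdict()["used"] / 1_000_000
tensorboard.add_scalar(tag="global/memory_usage", scalar_value=memory_mb, global_step=n_total_iter)
tensorboard.close()
```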
gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3