Commit bb9c5ead authored by VictorSanh, committed by Victor SANH

update distiller

parent a12ab0a8
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" The distiller to distil DistilBERT
-    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+""" The distiller to distil the student.
+    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
 """
 import os
 import math
@@ -28,16 +28,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import AdamW
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import RandomSampler, BatchSampler, DataLoader
 from transformers import WarmupLinearSchedule

 from utils import logger
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups


 class Distiller:
     def __init__(self,
                  params: dict,
-                 dataloader: Dataset,
+                 dataset: LmSeqsDataset,
                  token_probs: torch.tensor,
                  student: nn.Module,
                  teacher: nn.Module):
@@ -50,33 +53,47 @@ class Distiller:
         self.student = student
         self.teacher = teacher

-        self.dataloader = dataloader
-        if self.params.n_gpu > 1:
-            self.dataloader.split()
-        self.get_iterator(seed=params.seed)
+        self.student_config = student.config
+        self.vocab_size = student.config.vocab_size
+
+        if params.n_gpu <= 1:
+            sampler = RandomSampler(dataset)
+        else:
+            sampler = DistributedSampler(dataset)
+
+        if params.group_by_size:
+            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
+            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
+        else:
+            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
+
+        self.dataloader = DataLoader(dataset=dataset,
+                                     batch_sampler=sampler,
+                                     collate_fn=dataset.batch_sequences)

         self.temperature = params.temperature
         assert self.temperature > 0.
         self.alpha_ce = params.alpha_ce
         self.alpha_mlm = params.alpha_mlm
+        self.alpha_clm = params.alpha_clm
         self.alpha_mse = params.alpha_mse
         self.alpha_cos = params.alpha_cos
-        assert self.alpha_ce >= 0.
-        assert self.alpha_mlm >= 0.
-        assert self.alpha_mse >= 0.
-        assert self.alpha_cos >= 0.
-        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

-        self.mlm_mask_prop = params.mlm_mask_prop
-        assert 0.0 <= self.mlm_mask_prop <= 1.0
-        assert params.word_mask + params.word_keep + params.word_rand == 1.0
-        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
-        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
-        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
-        if self.fp16:
-            self.pred_probs = self.pred_probs.half()
-            self.token_probs = self.token_probs.half()
+        self.mlm = params.mlm
+        if self.mlm:
+            logger.info(f'Using MLM loss for LM step.')
+            self.mlm_mask_prop = params.mlm_mask_prop
+            assert 0.0 <= self.mlm_mask_prop <= 1.0
+            assert params.word_mask + params.word_keep + params.word_rand == 1.0
+            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
+            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
+            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
+            if self.fp16:
+                self.pred_probs = self.pred_probs.half()
+                self.token_probs = self.token_probs.half()
+        else:
+            logger.info(f'Using CLM loss for LM step.')

         self.epoch = 0
         self.n_iter = 0
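The constructor now builds its batches with the standard PyTorch samplers plus the repo-local `GroupedBatchSampler`. A rough, self-contained sketch of the same pattern follows; the toy `ToyLmDataset` is illustrative and not part of the repo (the real `LmSeqsDataset.batch_sequences` does the padding), and the length-grouping step is omitted:

```python
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, BatchSampler

class ToyLmDataset(Dataset):
    """Toy stand-in for LmSeqsDataset: variable-length token-id sequences."""
    def __init__(self, seqs):
        self.seqs = [torch.tensor(s, dtype=torch.long) for s in seqs]
        self.lengths = [len(s) for s in seqs]
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, i):
        return self.seqs[i]
    def batch_sequences(self, batch):
        # Pad to the longest sequence in the batch and return (token_ids, lengths).
        lengths = torch.tensor([len(s) for s in batch])
        token_ids = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)
        return token_ids, lengths

dataset = ToyLmDataset([[5, 6, 7], [8, 9], [10, 11, 12, 13]])
sampler = BatchSampler(RandomSampler(dataset), batch_size=2, drop_last=False)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
for token_ids, lengths in loader:
    print(token_ids.shape, lengths)
```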
@@ -86,12 +103,13 @@ class Distiller:
         self.last_loss = 0
         self.last_loss_ce = 0
         self.last_loss_mlm = 0
+        self.last_loss_clm = 0
         if self.alpha_mse > 0.: self.last_loss_mse = 0
         if self.alpha_cos > 0.: self.last_loss_cos = 0
         self.last_log = 0

         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
         if self.alpha_mse > 0.:
             self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
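For reference, the two loss modules set up here behave as follows; this is a standalone sketch with made-up tensors, not code from the repo:

```python
import torch
import torch.nn as nn

# CrossEntropyLoss with ignore_index=-1: positions labeled -1 contribute nothing
# to the LM loss, which is how padded / unmasked positions are skipped.
lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
logits = torch.randn(4, 10)            # 4 positions, vocabulary of 10
labels = torch.tensor([3, -1, 7, -1])  # only positions 0 and 2 are scored
print(lm_loss_fct(logits, labels))

# KLDivLoss(reduction='batchmean') expects log-probabilities as input and
# probabilities as target, and divides the summed KL by the batch size.
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
student_logp = torch.log_softmax(torch.randn(4, 10), dim=-1)
teacher_p = torch.softmax(torch.randn(4, 10), dim=-1)
print(ce_loss_fct(student_logp, teacher_p))
```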
@@ -99,7 +117,7 @@ class Distiller:
         logger.info('--- Initializing model optimizer')
         assert params.gradient_accumulation_steps >= 1
-        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
+        self.num_steps_epoch = len(self.dataloader)
         num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

         no_decay = ['bias', 'LayerNorm.weight']
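The scheduler's total step count is now taken straight from the dataloader length (number of batches) rather than being recomputed from the dataset size. A quick sanity check of that arithmetic with made-up numbers:

```python
# Toy numbers: 1000 batches per epoch, gradients accumulated over 4 batches, 3 epochs.
num_steps_epoch = 1000
gradient_accumulation_steps = 4
n_epoch = 3
num_train_optimization_steps = int(num_steps_epoch / gradient_accumulation_steps * n_epoch) + 1
print(num_train_optimization_steps)  # 751 optimizer updates in total
```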
@@ -140,43 +158,18 @@
             logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
             self.student = DistributedDataParallel(self.student,
                                                    device_ids=[params.local_rank],
-                                                   output_device=params.local_rank)
+                                                   output_device=params.local_rank,
+                                                   find_unused_parameters=True)

         self.is_master = params.is_master
         if self.is_master:
             logger.info('--- Initializing Tensorboard')
             self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
-            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
-
-    def get_iterator(self,
-                     seed: int = None):
-        """
-        Initialize the data iterator.
-        Each process has its own data iterator (iterating on its own random portion of the dataset).
-
-        Input:
-        ------
-        seed: `int` - The random seed.
-        """
-        logger.info('--- Initializing Data Iterator')
-        self.data_iterator = self.dataloader.get_iterator(seed=seed)
-
-    def get_batch(self):
-        """
-        Call the data iterator to output a new batch.
-        If the data iterator went through the whole dataset, create a new iterator.
-        """
-        assert hasattr(self, 'data_iterator')
-        try:
-            x = next(self.data_iterator)
-        except StopIteration:
-            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
-            self.data_iterator = self.dataloader.get_iterator()
-            x = next(self.data_iterator)
-        return x
-
-    def prepare_batch(self,
-                      batch):
+            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
+            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
+
+    def prepare_batch_mlm(self,
+                          batch):
         """
         Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
@@ -222,7 +215,7 @@
             assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

         _token_ids_real = token_ids[pred_mask]
-        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
+        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
         _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
         probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
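The hunk above decides, for every position selected for MLM, whether to replace the token with the mask token, keep it, or substitute a random token, by sampling from `pred_probs` with `torch.multinomial`. A self-contained toy version of that choice (the token ids and the BERT-like `mask_token_id`/`vocab_size` are illustrative):

```python
import torch

# Toy version of the mask/keep/rand choice made for positions selected for MLM.
pred_probs = torch.FloatTensor([0.8, 0.1, 0.1])   # word_mask, word_keep, word_rand
token_ids_real = torch.tensor([12, 45, 7, 33])    # tokens at the selected positions
mask_token_id, vocab_size = 103, 30522            # illustrative ids

token_ids_rand = token_ids_real.clone().random_(vocab_size)
token_ids_mask = token_ids_real.clone().fill_(mask_token_id)
# Draw one of {mask, keep, rand} independently for each selected position.
probs = torch.multinomial(pred_probs, len(token_ids_real), replacement=True)
token_ids = token_ids_mask * (probs == 0).long() \
          + token_ids_real * (probs == 1).long() \
          + token_ids_rand * (probs == 2).long()
print(probs, token_ids)
```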
@@ -230,8 +223,41 @@
         mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
         return token_ids, attn_mask, mlm_labels

+    def prepare_batch_clm(self,
+                          batch):
+        """
+        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
+
+        Input:
+        ------
+        batch: `Tuple`
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+        Output:
+        -------
+        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for CLM.
+        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
+        """
+        token_ids, lengths = batch
+        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+        assert token_ids.size(0) == lengths.size(0)
+
+        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
+        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+        return token_ids, attn_mask, clm_labels
+
     def round_batch(self,
                     x: torch.tensor,
                     lengths: torch.tensor):
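The new `prepare_batch_clm` builds the attention mask from the sequence lengths and labels padded positions with -1 so the loss ignores them. A minimal toy version of those two lines, with made-up token ids:

```python
import torch

# Toy version of the CLM batch preparation: mask out padding and label it -1.
token_ids = torch.tensor([[5, 6, 7, 0],
                          [8, 9, 0, 0]])          # 0 used as padding here
lengths = torch.tensor([3, 2])

attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -1                       # nothing to predict on padded positions
print(attn_mask)
print(clm_labels)
```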
@@ -269,7 +295,10 @@
             if ml1 % 8 != 0:
                 pad = 8 - (ml1 % 8)
                 ml2 = ml1 + pad
-                pad_id = self.params.special_tok_ids['pad_token']
+                if self.mlm:
+                    pad_id = self.params.special_tok_ids['pad_token']
+                else:
+                    pad_id = self.params.special_tok_ids['unk_token']
                 padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
                 x = torch.cat([x, padding_tensor], 1)
                 assert x.size() == (bs2, ml2)
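`round_batch` pads the sequence dimension up to a multiple of 8, which keeps fp16 kernels (tensor cores) efficient; the only change here is that the padding id depends on whether the student has a pad token (MLM) or not (CLM, where the unk token is used instead). A standalone sketch of the rounding, with an illustrative `pad_id`:

```python
import torch

# Toy version of the fp16-friendly rounding: pad the sequence length up to a multiple of 8.
x = torch.randint(0, 100, (4, 13))   # (bs, seq_length) with seq_length=13
pad_id = 0                           # illustrative; the real code picks pad_token or unk_token

ml1 = x.size(1)
if ml1 % 8 != 0:
    pad = 8 - (ml1 % 8)
    padding_tensor = torch.zeros(x.size(0), pad, dtype=torch.long, device=x.device).fill_(pad_id)
    x = torch.cat([x, padding_tensor], dim=1)
print(x.size())                      # torch.Size([4, 16])
```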
@@ -292,14 +321,16 @@
             if self.multi_gpu:
                 torch.distributed.barrier()

-            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
-            for __ in range(self.num_steps_epoch):
-                batch = self.get_batch()
+            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
+            for batch in iter_bar:
                 if self.params.n_gpu > 0:
                     batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)

-                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
-                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
+                if self.mlm:
+                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
+                else:
+                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
+                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                 iter_bar.update()
                 iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
@@ -317,7 +348,7 @@
     def step(self,
              input_ids: torch.tensor,
              attention_mask: torch.tensor,
-             mlm_labels: torch.tensor):
+             lm_labels: torch.tensor):
         """
         One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
         and possibly a parameter update (depending on the gradient accumulation).
@@ -326,17 +357,22 @@
         ------
         input_ids: `torch.tensor(bs, seq_length)` - The token ids.
         attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
-        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
+        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
-        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)       # (bs, seq_length, voc_size)
-        with torch.no_grad():
-            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)   # (bs, seq_length, voc_size)
+        if self.mlm:
+            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
+        else:
+            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)             # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)         # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()

         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
         #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
         if self.params.restrict_ce_to_mask:
-            mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
+            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)     # (bs, seq_length, voc_size)
         else:
             mask = attention_mask.unsqueeze(-1).expand_as(s_logits)     # (bs, seq_length, voc_size)
         s_logits_slct = torch.masked_select(s_logits, mask)             # (bs * seq_length * voc_size) modulo the 1s in mask
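The distillation loss is restricted to "active" positions by expanding a `(bs, seq_length)` mask over the vocabulary axis and running `torch.masked_select`. A toy version of that selection (using an explicit `.bool()` cast, which recent PyTorch requires for `masked_select`):

```python
import torch

# Toy version of selecting only the logits at active (non-padded / labeled) positions.
bs, seq_length, voc_size = 2, 3, 5
s_logits = torch.randn(bs, seq_length, voc_size)
attention_mask = torch.tensor([[1, 1, 0],
                               [1, 0, 0]])

mask = attention_mask.bool().unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask)              # flat, then reshaped per position
s_logits_slct = s_logits_slct.view(-1, voc_size)                 # (n_active_positions, voc_size)
print(s_logits_slct.size())                                      # torch.Size([3, 5])
```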
@@ -348,13 +384,20 @@
         loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                    F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
         loss = self.alpha_ce*loss_ce
         if self.alpha_mlm > 0.:
-            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
+            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
             loss += self.alpha_mlm * loss_mlm
+        if self.alpha_clm > 0.:
+            shift_logits = s_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
+            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                                        shift_labels.view(-1))
+            loss += self.alpha_clm * loss_clm
+
         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse
         if self.alpha_cos > 0.:
             s_hidden_states = s_hidden_states[-1]                               # (bs, seq_length, dim)
             t_hidden_states = t_hidden_states[-1]                               # (bs, seq_length, dim)
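This hunk combines the temperature-scaled KL distillation term with the new CLM term, which shifts logits and labels by one position so the model at position t predicts token t+1. A self-contained sketch of both losses with made-up shapes and coefficients (not the repo's actual hyperparameters):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy version of the two LM-step losses in this hunk.
bs, seq_length, voc_size = 2, 5, 11
temperature, alpha_ce, alpha_clm = 2.0, 5.0, 2.0
s_logits = torch.randn(bs, seq_length, voc_size, requires_grad=True)
t_logits = torch.randn(bs, seq_length, voc_size)
lm_labels = torch.randint(0, voc_size, (bs, seq_length))

# Distillation term: KL between temperature-softened student and teacher distributions,
# rescaled by T**2 so gradient magnitudes stay comparable to T=1.
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = ce_loss_fct(F.log_softmax(s_logits.view(-1, voc_size) / temperature, dim=-1),
                      F.softmax(t_logits.view(-1, voc_size) / temperature, dim=-1)) * temperature ** 2

# CLM term: predict token t+1 from position t, so logits and labels are shifted by one.
lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = lm_loss_fct(shift_logits.view(-1, voc_size), shift_labels.view(-1))

loss = alpha_ce * loss_ce + alpha_clm * loss_clm
loss.backward()
print(loss_ce.item(), loss_clm.item())
```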
@@ -376,6 +419,8 @@
         self.last_loss_ce = loss_ce.item()
         if self.alpha_mlm > 0.:
             self.last_loss_mlm = loss_mlm.item()
+        if self.alpha_clm > 0.:
+            self.last_loss_clm = loss_clm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
         if self.alpha_cos > 0.:
@@ -452,6 +497,8 @@
         self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
         if self.alpha_mlm > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
+        if self.alpha_clm > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
         if self.alpha_cos > 0.:
...