# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""

import math
import os
import sys
import time

import numpy as np
import psutil
from tqdm import trange, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, BatchSampler, DataLoader

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

from transformers import get_linear_schedule_with_warmup

from utils import logger
from lm_seqs_dataset import LmSeqsDataset
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups


class Distiller:
    """
    Knowledge-distillation trainer.

    Trains `student` to match the temperature-softened output distribution of a frozen
    `teacher` (KL-divergence loss). This can be combined with a language-modeling loss
    (MLM or CLM), an MSE loss on the logits and a cosine loss on the last hidden states.
    Adapted in part from the Facebook XLM training code.
    """

    def __init__(self,
                 params: dict,
                 dataset: LmSeqsDataset,
                 token_probs: torch.Tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        """
        Build the data pipeline, the losses and the optimizer/scheduler. Also set up the
        optional fp16 / distributed wrappers and, on the master process, the tensorboard writer.

        Input:
        ------
            params: training configuration. Fields are read as attributes (`params.dump_path`, ...),
                so despite the `dict` annotation a namespace-like object is expected.
            dataset: `LmSeqsDataset` - the training sequences. It must expose `lengths` and `batch_sequences`.
            token_probs: `torch.Tensor` - per-token sampling weights, indexed by token id, used to pick
                the MLM targets. Only used when `params.mlm` is set.
            student: `nn.Module` - the model being trained.
            teacher: `nn.Module` - the reference model. It is only ever run under `torch.no_grad()`.
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)
        # Keep a handle on the underlying sampler. `train` must call `set_epoch` on a
        # DistributedSampler at every epoch; otherwise every epoch reuses the same permutation.
        self._base_sampler = sampler

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info(f'Using MLM loss for LM step.')
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            # Probabilities that a selected token is replaced by the mask token / kept / replaced
            # by a random token (see `prepare_batch_mlm`).
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f'Using CLM loss for LM step.')

        # Training counters and the last values of each loss (for logging).
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=warmup_steps,
                                                         num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank,
                                                       find_unused_parameters=True)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self,
                          batch):
        """
        Prepare the batch for MLM. From the token_ids and the lengths, compute the attention mask
        and the masked labels.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.Tensor(bs, seq_length)` - The token ids for each of the sequences. It is padded.
                lengths: `torch.Tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.Tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.Tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.Tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        # Sample ceil(mlm_mask_prop * n_real_tokens) target positions, weighted by `token_probs`.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        # Padding positions are never prediction targets.
        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # With fp16 and more than 8 targets, drop targets so their count is a multiple of 8 (faster).
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        # For every target pick one of: mask token (0), original token (1), random token (2).
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self,
                          batch):
        """
        Prepare the batch for CLM. From the token_ids and the lengths, compute the attention mask
        and the labels.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.Tensor(bs, seq_length)` - The token ids for each of the sequences. It is padded.
                lengths: `torch.Tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.Tensor(bs, seq_length)` - The token ids (unmodified apart from fp16 rounding).
            attn_mask: `torch.Tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.Tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self,
                    x: torch.Tensor,
                    lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
        This is a no-op when fp16 is off or when the batch has fewer than 8 sentences.

        Input:
        ------
            x: `torch.Tensor(bs, seq_length)` - The token ids.
            lengths: `torch.Tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.Tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.Tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # Make the number of sentences a multiple of 8 by dropping a random subset.
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]

        # Right-pad so the sequence length is a multiple of 8.
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids['pad_token']
            else:
                pad_id = self.params.special_tok_ids['unk_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if isinstance(self._base_sampler, DistributedSampler):
                # Reseed the distributed shuffle; without this every epoch sees the same order.
                self._base_sampler.set_epoch(self.epoch)
            if self.multi_gpu:
                torch.distributed.barrier()

            # Iterating over the tqdm wrapper already advances the bar; no manual `update()` needed.
            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master:
                logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
            logger.info('Training is finished')

    def step(self,
             input_ids: torch.Tensor,
             attention_mask: torch.Tensor,
             lm_labels: torch.Tensor):
        """
        One optimization step: a forward pass of the student AND the teacher, then a backward pass on
        the loss (for gradient accumulation). This may be followed by a parameter update, depending on
        the gradient accumulation.

        Input:
        ------
        input_ids: `torch.Tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.Tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.Tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)     # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)            # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)       # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))      # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))      # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        # Temperature-scaled KL divergence; the T^2 factor keeps gradient magnitudes comparable across T.
        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce

        if self.alpha_mlm > 0.:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.:
            # Position t predicts token t+1.
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                        shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)     # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)        # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)                # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)        # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)                # (bs * seq_length, dim)

            # Target 1 everywhere: push each student vector towards the teacher's direction.
            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalize the loss (for gradient accumulation or distributed training), then run the
        backward pass on it. This may be followed by a parameter update, depending on the
        gradient accumulation. Also update the metrics for tensorboard.

        Raises SystemExit(1) if the loss is NaN.
        """
        # A NaN loss means training has diverged: stop rather than push NaNs into the weights.
        if torch.isnan(loss).any():
            logger.error('NaN detected')
            # Non-zero status so launchers see a failure (the previous bare `exit()` reported success).
            sys.exit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_clm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)

        # `get_last_lr` (PyTorch >= 1.4) is the supported accessor; `get_lr` warns there but is all older versions have.
        if hasattr(self.scheduler, 'get_last_lr'):
            current_lr = self.scheduler.get_last_lr()[0]
        else:
            current_lr = self.scheduler.get_lr()[0]
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=current_lr, global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Called at the end of an epoch (a full pass over the dataset).
        Do some tensorboard logging and checkpoint saving, then reset the per-epoch counters.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            # Guard against an empty epoch (no steps), which would otherwise divide by zero.
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/max(self.n_iter, 1), global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the student's config and state dict under `dump_path`. Only by the master process.
        """
        if not self.is_master:
            return
        # Unwrap DistributedDataParallel so the saved keys carry no `module.` prefix.
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))