"vscode:/vscode.git/clone" did not exist on "61c494499382a851db9980de780bc76e67127d8d"
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
    Adapted in part from Facebook, Inc.'s XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time

import psutil
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


class Distiller:
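    """
    Knowledge-distillation trainer: fits the `student` model to the `teacher` model on `dataset`
    using a weighted sum of distillation (CE/KL), LM (MLM or CLM), MSE and cosine losses.

    Minimal usage sketch (the `params` namespace is built by the accompanying training script;
    shown here only as an illustration):

        distiller = Distiller(params=params, dataset=dataset, token_probs=token_probs,
                              student=student, teacher=teacher)
        distiller.train()
    """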
    def __init__(
        self, params: dict, dataset: LmSeqsDataset, token_probs: torch.Tensor, student: nn.Module, teacher: nn.Module
    ):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

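        # Single-process training shuffles with a RandomSampler; distributed training shards the dataset with a DistributedSampler.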
        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

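        # Optionally batch sequences of similar lengths together (GroupedBatchSampler) to reduce padding.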
        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

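        # Softmax temperature used to soften both distributions in the distillation (KL) loss.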
        self.temperature = params.temperature
        assert self.temperature > 0.0

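        # Weights of the loss terms: distillation CE/KL, MLM, CLM, MSE on logits, cosine on hidden states.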
        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        )

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i"
            % sum([p.numel() for p in self.student.parameters() if p.requires_grad])
        )
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
        )

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student, self.optimizer, opt_level=self.params.fp16_opt_level
            )
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

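        # Only the master process writes TensorBoard logs (and, later, checkpoints).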
        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each sequence, padded to the same length.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

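        # Sample ~mlm_mask_prop of the token positions to predict, without replacement, weighted by token_probs.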
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words that is a multiple of 8 (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
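        # For each selected position, pred_probs = [word_mask, word_keep, word_rand] decides whether to
        # replace the token with [MASK], keep it, or draw a random token (typically the BERT-style 80/10/10 split).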
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = (
            _token_ids_mask * (probs == 0).long()
            + _token_ids_real * (probs == 1).long()
            + _token_ids_rand * (probs == 2).long()
        )
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each sequence, padded to the same length.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids (unchanged for CLM, apart from the optional fp16 batch rounding).
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

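        # Rounding both dimensions to multiples of 8 lets fp16 matmul kernels hit tensor cores efficiently.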
        # make the number of sentences a multiple of 8
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # make the sequence length a multiple of 8
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix(
                    {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
                )
            iter_bar.close()

            if self.is_master:
                logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, lm_labels: torch.Tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=attention_mask
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=None
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=None
                )  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
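        # restrict_ce_to_mask limits distillation to positions that have a label to predict (lm_labels > -1);
        # otherwise every non-padding position is used.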
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

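        # KLDivLoss expects log-probabilities for the student and probabilities for the teacher; the T**2 factor
        # compensates for the 1/T**2 scaling of the soft-target gradients (Hinton et al., 2015).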
        loss_ce = (
            self.ce_loss_fct(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
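            # Standard causal-LM shift: the logits at position t predict the token at position t + 1.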
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
                0
            )  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
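        # The cosine loss pulls the student's and teacher's last hidden states together on non-padding
        # positions (a target of 1 in CosineEmbeddingLoss asks for maximal cosine similarity).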
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

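            # apex amp applies (dynamic) loss scaling before backward to avoid fp16 gradient underflow.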
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(
                tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
            )
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(
                tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
            )

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(
            tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
        )
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
            )
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
            )
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
            )
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
            )
        self.tensorboard.add_scalar(
            tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
        )

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(
            tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
        )

    def end_epoch(self):
        """
        Called at the end of an epoch (one full pass over the dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            self.tensorboard.add_scalar(
                tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
            )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))