# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
    Adapted in part from Facebook, Inc.'s XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time

import psutil
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


class Distiller:
    def __init__(
        self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
    ):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

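        # Single-process training shuffles with a RandomSampler; distributed training shards the dataset with a DistributedSampler.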
        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

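        # Optionally bucket sequences of similar lengths into the same batches (GroupedBatchSampler) to limit padding.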
        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

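        # Weights of the individual losses: distillation CE, MLM/CLM cross-entropy, MSE on logits, cosine on hidden states.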
        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

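        # LM objective used for the student: masked language modeling (BERT-like) or causal language modeling (GPT-like).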
        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

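        # Bookkeeping: epoch/iteration counters and the latest value of each loss, used for logging.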
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

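        # Loss functions: KL divergence for the soft-target distillation loss, cross-entropy for the hard-label LM loss, plus optional MSE and cosine losses.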
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        )

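        # Apply weight decay to all trainable parameters except biases and LayerNorm weights.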
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i"
            % sum([p.numel() for p in self.student.parameters() if p.requires_grad])
        )
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
        )

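        # Linear warmup over warmup_prop of the optimization steps, followed by a linear decay of the learning rate.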
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

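        # Mixed-precision training through NVIDIA apex; the teacher only runs forward passes, so it is simply cast to fp16.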
        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student, self.optimizer, opt_level=self.params.fp16_opt_level
            )
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences. They are padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

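        # Select the positions to mask: mlm_mask_prop of the real tokens, sampled without replacement with probabilities given by token_probs.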
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words that is a multiple of 8 (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

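        # For each selected position, pred_probs decides whether to replace the token with the mask token, keep it unchanged, or use a random token.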
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = (
            _token_ids_mask * (probs == 0).long()
            + _token_ids_real * (probs == 1).long()
            + _token_ids_rand * (probs == 2).long()
        )
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences. They are padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for CLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # make the number of sentences a multiple of 8
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # make the sequence length a multiple of 8
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)

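                # Build the model inputs and LM labels for the chosen objective, then run one optimization step.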
                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix(
                    {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
                )
            iter_bar.close()

            if self.is_master:
                logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
        """
        One optimization step: forward pass of student AND teacher, backward pass on the loss (gradients are accumulated),
        and possibly a parameter update (depending on the gradient accumulation step).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            student_outputs = self.student(
                input_ids=input_ids, attention_mask=attention_mask
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                teacher_outputs = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            student_outputs = self.student(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
        s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
        t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
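        # Positions used for the distillation losses: only those with an LM label if restrict_ce_to_mask is set, otherwise every attended position.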
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

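        # Distillation loss: KL divergence between the temperature-softened student and teacher distributions, scaled by temperature**2.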
        loss_ce = (
            self.ce_loss_fct(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

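        # Optional MSE between the selected student and teacher logits.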
        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
                0
            )  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
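        # Optional cosine embedding loss aligning the student's and teacher's last hidden states on non-padded positions.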
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalize the loss (for gradient accumulation or distributed training), run the backward pass on the loss,
        and possibly update the parameters (depending on the gradient accumulation).
        Also updates the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

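            # amp scales the loss so that fp16 gradients do not underflow during the backward pass.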
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
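        # Update the parameters only every gradient_accumulation_steps iterations, after clipping the gradients.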
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

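        # Log the mean/std of every student parameter and of its gradient (when available).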
        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(
                tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
            )
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(
                tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
            )

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(
            tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
        )
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
            )
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
            )
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
            )
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
            )
        self.tensorboard.add_scalar(
            tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
        )

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(
            tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
        )

    def end_epoch(self):
        """
        Called at the end of an epoch (a full pass over the dataset).
        Performs tensorboard logging and checkpoint saving.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            self.tensorboard.add_scalar(
                tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
            )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))