# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import sys
import time

import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


class Distiller:
    """
    Distillation trainer: optimizes a student model against the outputs of a frozen teacher model.

    The training loss is a weighted sum (weights `alpha_*` read from `params`) of:
        - a temperature-scaled KL divergence between student and teacher logits (`alpha_ce`),
        - a masked or causal language modeling cross-entropy on the student logits (`alpha_mlm` / `alpha_clm`),
        - an optional MSE between student and teacher logits (`alpha_mse`),
        - an optional cosine embedding loss between the last hidden states (`alpha_cos`).
    """

    # Label value written where there is nothing to predict. The LM loss must ignore exactly this value,
    # so label preparation, the loss function and the logit mask all refer to this single constant.
    LM_IGNORE_INDEX = -1

    def __init__(
        self, params: dict, dataset: "LmSeqsDataset", token_probs: torch.Tensor, student: nn.Module, teacher: nn.Module
    ):
        """
        Build the data loader, the losses, the optimizer/scheduler and the (optional) fp16 / distributed wrappers.

        Input:
        ------
            params: The training configuration. NOTE(review): annotated `dict` but only ever read through
                attribute access (`params.dump_path`, ...), so an argparse-style namespace is expected.
            dataset: `LmSeqsDataset` - Provides `lengths` and the `batch_sequences` collate function.
            token_probs: `torch.tensor(vocab_size)` - Per-token sampling weights used to pick the MLM targets.
                Only used when `params.mlm` is set.
            student: `nn.Module` - The model being trained. Must expose `config.vocab_size`.
            teacher: `nn.Module` - The model being imitated. Only run under `torch.no_grad()`.
        """
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            # Tolerant comparison: an exact `== 1.0` can fail on valid inputs because of float rounding.
            assert math.isclose(params.word_mask + params.word_keep + params.word_rand, 1.0)
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        # Running counters and last-seen loss values (reported to the progress bar and tensorboard).
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        # The labels produced by `prepare_batch_mlm` / `prepare_batch_clm` use LM_IGNORE_INDEX (-1) for
        # "nothing to predict"; the default `ignore_index` (-100) would treat -1 as an out-of-range class.
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=self.LM_IGNORE_INDEX)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        )

        # Parameters whose name contains one of these substrings are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i"
            % sum(p.numel() for p in self.student.parameters() if p.requires_grad)
        )
        logger.info("------ Number of parameters (student): %i" % sum(p.numel() for p in self.student.parameters()))
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
        )

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

        if self.fp16:
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                ) from err
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student, self.optimizer, opt_level=self.params.fp16_opt_level
            )
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a
                `LM_IGNORE_INDEX` (-1) where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.clone()

        # Sample the positions to predict, weighted by the per-token probabilities.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = True
        pred_mask = pred_mask.view(bs, max_seq_len)

        # Padding positions are never prediction targets.
        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = False

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    # Drop the first `n1 - n2` selected positions to land on a multiple of 8.
                    pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = False
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        # For every selected position draw one of: 0 -> mask token, 1 -> keep original, 2 -> random token.
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = (
            _token_ids_mask * (probs == 0).long()
            + _token_ids_real * (probs == 1).long()
            + _token_ids_rand * (probs == 2).long()
        )
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
        mlm_labels[~pred_mask] = self.LM_IGNORE_INDEX

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids (possibly re-batched by `round_batch`).
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a
                `LM_IGNORE_INDEX` (-1) where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.clone()
        # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
        clm_labels[~attn_mask] = self.LM_IGNORE_INDEX

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
        Batches with fewer than 8 sequences, and all batches when fp16 is off, are returned untouched.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            # Keep a random subset of `bs2` sequences and trim to the longest one that is left.
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop: `params.n_epoch` passes over the data loader, one `step` per batch,
        then a final checkpoint saved as `pytorch_model.bin` by the master process.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                # No manual `iter_bar.update()` here: iterating over the bar already advances it,
                # an extra call would count every batch twice.
                iter_bar.set_postfix(
                    {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
                )
            iter_bar.close()

            if self.is_master:
                logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, lm_labels: torch.Tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=attention_mask
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=None
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=None
                )  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            # Only the positions that carry a real label contribute to the distillation loss.
            mask = (lm_labels != self.LM_IGNORE_INDEX).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        # Temperature-softened KL divergence, rescaled by T^2.
        loss_ce = (
            self.ce_loss_fct(
                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                F.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            # Position t predicts token t + 1.
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
                0
            )  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            hidden_mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, hidden_mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, hidden_mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

            # Target +1 for every pair: the two vectors should point in the same direction.
            target = s_hidden_states_slct.new_ones(s_hidden_states_slct.size(0))  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.

        Terminates the process with a non-zero exit status if the loss is NaN.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            logger.error("NaN detected")
            # Non-zero status so that launchers see the failure (a bare `exit()` reports success).
            sys.exit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        # Only update the parameters once every `gradient_accumulation_steps` backward passes.
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.

        Writes per-parameter mean/std (and gradient mean/std when a gradient exists), the loss values,
        the current learning rate, the used system memory and the time elapsed since the last log.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(
                tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
            )
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(
                tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
            )

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(
            tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
        )
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
            )
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
            )
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
            )
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
            )
        # `get_last_lr()` is the supported way to read the current rate outside of `scheduler.step()`.
        self.tensorboard.add_scalar(
            tag="learning_rate/lr", scalar_value=self.scheduler.get_last_lr()[0], global_step=self.n_total_iter
        )

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory().used / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(
            tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
        )

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving, then reset the per-epoch counters.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            # `max(..., 1)` keeps an epoch without any batch from raising ZeroDivisionError.
            self.tensorboard.add_scalar(
                tag="epoch/loss", scalar_value=self.total_loss_epoch / max(self.n_iter, 1), global_step=self.epoch
            )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.

        Writes the student config and its state dict (unwrapped from `.module` when the student is
        wrapped for distributed training) into `self.dump_path`.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))