import logging
import os
import sys

import torch
import transformers
from torch.optim import AdamW
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertForPreTraining,
    GPT2Config,
    GPT2LMHeadModel,
    RobertaConfig,
    RobertaForMaskedLM,
    get_linear_schedule_with_warmup,
)

from colossalai.core import global_context as gpc
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.optimizer import FusedAdam, HybridAdam

sys.path.append(os.getcwd())
from collections import OrderedDict

import torch.nn as nn
from model.bert import BertForMaskedLM
from model.deberta_v2 import DebertaV2ForMaskedLM

__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']


def get_new_state_dict(state_dict, start_index=13):
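    """Strip the first `start_index` characters from every key of a state dict.

    Intended to drop a wrapper prefix (presumably something like `'model.module.'`,
    which is 13 characters) so a checkpoint saved from a wrapped model can be
    loaded into a bare one.
    """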
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[start_index:]
        new_state_dict[name] = v
    return new_state_dict


class LMModel(nn.Module):
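    """Thin wrapper around a masked-LM model: enables gradient checkpointing when
    `args.checkpoint_activations` is set and delegates `forward` to the wrapped model.

    Note: `get_model` below currently returns the bare model; the wrapping call
    there is commented out.
    """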

    def __init__(self, model, config, args):
        super().__init__()

        self.checkpoint = args.checkpoint_activations
        self.config = config
        self.model = model
        if self.checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # Only return lm_logits
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def get_model(args, logger):
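    """Build the masked-LM model selected by `args.mlm` ('bert' or 'deberta_v2'),
    optionally load pretrained weights from `args.load_pretrain_model`, and return
    `(config, model, numel)`, where `numel` is the total parameter count.
    """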

    if args.mlm == 'bert':
        config = transformers.BertConfig.from_json_file(args.bert_config)
        model = BertForMaskedLM(config)
    elif args.mlm == 'deberta_v2':
        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
        model = DebertaV2ForMaskedLM(config)
    else:
        raise ValueError(f"Invalid mlm model type: {args.mlm}")

    if len(args.load_pretrain_model) > 0:
        assert os.path.exists(args.load_pretrain_model)
        # load_checkpoint(args.load_pretrain_model, model, strict=False)
        m_state_dict = torch.load(args.load_pretrain_model,
                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
        # new_state_dict = get_new_state_dict(m_state_dict)
        model.load_state_dict(m_state_dict,
                              strict=True)    # must ensure that every process loads identical parameters
        logger.info("loaded pretrained model successfully")

    numel = sum(p.numel() for p in model.parameters())
    if args.checkpoint_activations:
        model.gradient_checkpointing_enable()
    # model = LMModel(model, config, args)

    return config, model, numel


def get_optimizer(model, lr):
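    """Create a ColossalAI HybridAdam optimizer with weight decay (0.1) applied to
    all parameters except biases and LayerNorm/gamma/beta weights.
    """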
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    # configure the weight decay for bert models
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.1
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95])
    return optimizer


def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
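    """Linear warmup for `warmup_steps` steps followed by linear decay to zero over
    `total_steps`, using HuggingFace `get_linear_schedule_with_warmup`.
    """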
    # warmup_steps = int(total_steps * warmup_ratio)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=warmup_steps,
                                                   num_training_steps=total_steps,
                                                   last_epoch=last_epoch)
    # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
    return lr_scheduler


def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
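    """Save the model weights to `<path>_pytorch_model.bin` and the optimizer /
    lr-scheduler state plus training progress (epoch, shard, global_step) to
    `<path>.op_lrs`. Only rank 0 writes to disk, but every rank must call
    `state_dict()`.
    """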
    model_path = path + '_pytorch_model.bin'
    optimizer_lr_path = path + '.op_lrs'
    checkpoint = {}
    checkpoint['optimizer'] = optimizer.state_dict()
    checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
    checkpoint['epoch'] = epoch
    checkpoint['shard'] = shard
    checkpoint['global_step'] = global_step
    model_state = model.state_dict()    # every process must call model.state_dict()
    if gpc.get_global_rank() == 0:
        torch.save(checkpoint, optimizer_lr_path)
        torch.save(model_state, model_path)
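

# Usage sketch (assuming an `args` namespace with the fields used above: mlm,
# bert_config, load_pretrain_model, checkpoint_activations, plus hypothetical
# `lr` and `total_steps` values):
#
#   config, model, numel = get_model(args, logger)
#   optimizer = get_optimizer(model, lr=args.lr)
#   lr_scheduler = get_lr_scheduler(optimizer, total_steps=args.total_steps)
#   ... training loop ...
#   save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step)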