# train.py: BERT pretraining script built on ColossalAI's legacy engine API.

import argparse

import torch
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
5
from data.dummy_dataloader import DummyDataloader
6
7
8
9
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert

10
11
import colossalai
from colossalai.kernel import LayerNorm
12
13
14
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
15
from colossalai.legacy.engine.schedule import PipelineSchedule
16
from colossalai.legacy.utils import is_using_pp
17
18
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
19
from colossalai.utils import MultiTimer
20
21
22
23
24
25
26
27
28
29
30
31


def process_batch_data(batch_data):
    """Split one batch into keyword-style model inputs and loss targets.

    Args:
        batch_data: 6-tuple ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``.

    Returns:
        ``(inputs, targets)`` where ``inputs`` is a dict with keys ``attention_masks``,
        ``tokentype_ids`` and ``lm_labels``. On the first pipeline rank only, it also
        carries ``input_ids`` as its first key. ``targets`` is a dict with keys
        ``loss_mask`` and ``sentence_order``.
    """
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data

    inputs = {}
    # Only the first pipeline rank receives the raw token ids.
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        inputs['input_ids'] = tokens
    inputs['attention_masks'] = padding_mask
    inputs['tokentype_ids'] = types
    inputs['lm_labels'] = lm_labels

    targets = {'loss_mask': loss_mask, 'sentence_order': sentence_order}
    return inputs, targets


def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with a single boolean attribute ``synthetic``, set to True
        when ``-s``/``--synthetic`` is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-s',
                     '--synthetic',
                     action="store_true",
                     help="whether use synthetic data")
    options = cli.parse_args()
    return options


def pipeline_data_process_func(stage_output, micro_batch_data):
    """Build the ``(data, label)`` tuples that the pipeline schedule feeds to a stage.

    ``main`` installs this as ``engine.schedule.data_process_func`` when pipeline
    parallelism is in use.

    Args:
        stage_output: the input this stage receives from upstream (presumably the previous
            stage's output). It is ignored on the first pipeline stage, which starts from
            the raw token ids instead.
        micro_batch_data: 6-tuple ``(tokens, types, sentence_order, loss_mask, lm_labels,
            padding_mask)`` for one micro-batch.

    Returns:
        ``(data, label)`` where
        ``data = (tokens or stage_output, padding_mask, types, lm_labels)`` and
        ``label = (loss_mask, sentence_order)``.
    """
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    # Only the leading tensor differs between stages; the rest of the inputs and the
    # labels are the same on every stage.
    leading_input = tokens if gpc.is_first_rank(ParallelMode.PIPELINE) else stage_output
    data = (leading_input, padding_mask, types, lm_labels)
    label = (loss_mask, sentence_order)
    return data, label


def _holds_loss():
    """Return True on ranks whose loss value is accumulated and logged.

    Without an initialized pipeline group every rank qualifies. With one, only the
    last pipeline rank does.
    """
    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)


def _split_weight_decay_params(model):
    """Partition ``model``'s parameters into weight-decay and no-weight-decay groups.

    Every parameter of a ``LayerNorm`` module, and every parameter registered under the
    name ``'bias'`` on any other module, goes in the no-decay group
    (``weight_decay=0.0``). All remaining non-None parameters go in the decay group,
    which inherits the optimizer's default weight decay.

    Returns:
        ``(weight_decay_params, no_weight_decay_params)``: two optimizer param-group dicts.
    """
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        is_layernorm = isinstance(module_, LayerNorm)
        # _parameters holds only the parameters registered directly on this module.
        for name, param in module_._parameters.items():
            if param is None:
                continue
            if is_layernorm or name == 'bias':
                no_weight_decay_params['params'].append(param)
            else:
                weight_decay_params['params'].append(param)
    return weight_decay_params, no_weight_decay_params


def _current_loss_scale(optimizer):
    """Return ``optimizer.optim.loss_scale`` as a Python number, or None if it is absent.

    The script reads the loss scale through the wrapped optimizer's ``optim`` attribute.
    Wrappers without that attribute path used to crash the progress log with
    AttributeError; they now report ``None``.
    """
    inner = getattr(optimizer, 'optim', None)
    scale = getattr(inner, 'loss_scale', None)
    if scale is None:
        return None
    return scale.item() if torch.is_tensor(scale) else scale


def _train_step(engine, data_source, use_pipeline):
    """Run one forward/backward/optimizer step and return the training loss.

    Args:
        engine: the ColossalAI engine returned by ``colossalai.initialize``.
        data_source: the pipeline data iterator when ``use_pipeline`` is set, otherwise
            the dataloader passed to ``get_batch_for_sequence_parallel``.
        use_pipeline: whether to drive the step through ``engine.execute_schedule``.

    Returns:
        The loss from this step. Callers only read it on ranks where ``_holds_loss()``.
    """
    engine.zero_grad()
    if use_pipeline:
        _, _, loss = engine.execute_schedule(data_source, return_output_label=False)
    else:
        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
            data_source)
        lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
        loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
        engine.backward(loss)
    engine.step()
    return loss


def _eval_step(engine, data_source, use_pipeline):
    """Compute one validation loss under ``torch.no_grad()``.

    Takes the same arguments as ``_train_step``; runs a forward pass only.
    """
    with torch.no_grad():
        if use_pipeline:
            _, _, loss = engine.execute_schedule(data_source, forward_only=True, return_output_label=False)
        else:
            tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                data_source)
            lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
            loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
    return loss


def main():
    """Launch ColossalAI, build the BERT model, optimizer and LR schedule, then train.

    Hyper-parameters come from ``./config.py`` via ``gpc.config``. Every
    ``EVAL_INTERVAL`` steps the loop runs ``EVAL_ITERS`` validation batches. It then
    logs the mean train and eval loss, the loss scale, the learning rate and the timer
    statistics, and resets the trackers.
    """
    args = parse_args()

    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
    logger = get_dist_logger()

    if not args.synthetic:
        # Only the synthetic DummyDataloader path exists in this script; say so instead of
        # silently ignoring the missing flag.
        logger.warning("Real-data loading is not implemented in this script; using synthetic data", ranks=[0])

    # Synthetic dataloaders: the global batch is divided across data-parallel ranks.
    batch_size_per_gpu = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
    vocab_size = 30528
    trainloader = DummyDataloader(batch_size=batch_size_per_gpu,
                                  vocab_size=vocab_size,
                                  seq_length=gpc.config.SEQ_LENGTH)
    validloader = DummyDataloader(batch_size=batch_size_per_gpu,
                                  vocab_size=vocab_size,
                                  seq_length=gpc.config.SEQ_LENGTH)
    logger.info("Dataloaders are built", ranks=[0])

    # Model
    is_naive_fp16 = bool(hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE)
    use_pipeline = is_using_pp()
    model_kwargs = dict(vocab_size=vocab_size,
                        hidden_size=gpc.config.HIDDEN_SIZE,
                        max_sequence_length=gpc.config.SEQ_LENGTH,
                        num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
                        convert_fp16_to_fp32_in_softmax=True,
                        is_naive_fp16=is_naive_fp16,
                        add_binary_head=gpc.config.ADD_BINARY_HEAD)
    if use_pipeline:
        model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **model_kwargs)
    else:
        model = BertForPretrain(num_layers=gpc.config.DEPTH, **model_kwargs)

    model = model.half()
    model.reset_parameters()
    logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0])

    total_numel = sum(p.numel() for p in model.parameters())
    logger.info(f"This model has {total_numel} parameters")

    # Criterion
    criterion = BertLoss()
    logger.info("Criterion is built", ranks=[0])

    # Optimizer: LayerNorm parameters and biases are excluded from weight decay.
    weight_decay_params, no_weight_decay_params = _split_weight_decay_params(model)
    logger.info(
        f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
    )
    optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
                          lr=gpc.config.LR,
                          weight_decay=gpc.config.WEIGHT_DECAY)
    logger.info("Optimizer is built", ranks=[0])

    # LR schedule following the Megatron-LM setting: warmup for WARMUP_FRACTION of
    # DECAY_ITERS, with decay_style='linear' over DECAY_ITERS steps.
    warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
    lr_scheduler = AnnealingLR(optimizer=optimizer,
                               max_lr=gpc.config.LR,
                               min_lr=gpc.config.MIN_LR,
                               warmup_steps=warmup_steps,
                               decay_steps=gpc.config.DECAY_ITERS,
                               decay_style='linear')
    logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

    engine, *_ = colossalai.initialize(model, optimizer, criterion, verbose=True)

    timer = MultiTimer()

    # Running loss sums, averaged and reset at every evaluation interval.
    accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
    accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda()

    if use_pipeline:
        # execute_schedule consumes these iterators directly.
        train_source = SequenceParallelDataIterator(trainloader)
        valid_source = SequenceParallelDataIterator(validloader)
        engine.schedule.data_process_func = pipeline_data_process_func
    else:
        train_source, valid_source = trainloader, validloader

    # Loop-invariant: which rank prints the periodic progress line.
    if gpc.is_initialized(ParallelMode.PIPELINE):
        log_ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
    else:
        log_ranks = [0]

    logger.info("start training")

    for step in range(1, gpc.config.TRAIN_ITERS + 1):
        timer.start('train-iterations')
        engine.train()
        train_loss = _train_step(engine, train_source, use_pipeline)
        timer.stop('train-iterations', keep_in_history=True)

        if _holds_loss():
            # detach(): in the non-pipeline path train_loss is still attached to the autograd
            # graph. An in-place add would make the tracker require grad and keep chaining
            # graph nodes across steps.
            accumulated_train_loss += train_loss.detach()

        lr_scheduler.step()

        if step % gpc.config.EVAL_INTERVAL != 0:
            continue

        engine.eval()
        for _ in range(gpc.config.EVAL_ITERS):
            eval_loss = _eval_step(engine, valid_source, use_pipeline)
            if _holds_loss():
                accumulated_eval_loss += eval_loss

        if _holds_loss():
            accumulated_eval_loss /= gpc.config.EVAL_ITERS
            accumulated_train_loss /= gpc.config.EVAL_INTERVAL

        timer_string = ' | '.join(f"{n}: {t.get_history_mean()*1000:.5f}" for n, t in timer)
        lr = list(engine.optimizer.param_groups)[0]['lr']
        loss_scale = _current_loss_scale(engine.optimizer)

        logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
                    f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale} ' +
                    f"| Learning rate: {lr} | " + timer_string,
                    ranks=log_ranks)

        for _, t in timer:
            t.reset()
        accumulated_eval_loss.zero_()
        accumulated_train_loss.zero_()


# Script entry point: start training only when run directly, never on import.
if __name__ == "__main__":
    main()