run_pretraining.py
import math
import os
import time
from functools import partial

import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm
from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer


def main():

    args = parse_args()
    launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

    if args.vscode_debug:
        colossalai.launch(config={},
                          rank=args.rank,
                          world_size=args.world_size,
                          host=args.host,
                          port=args.port,
                          backend=args.backend)
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch(config={})    #args.colossal_config
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
            f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
        )

    log_args(logger, args)
    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    init_dev = get_current_device()

    # build model, optimizer and criterion
    if args.distplan.startswith("CAI"):
        # all param must use the same process group.
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build the BERT model
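        # ColoInitContext allocates parameters in fp16 directly on the target device.
        # With --shardinit, each parameter is created with a ShardSpec over the full
        # process group, so no single rank materializes the whole model at init time.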
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            config, model, numel = get_model(args, logger)

        # assign running configurations
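        # CAI_ZeRO1/CAI_ZeRO2 keep a full fp16 model replica per rank and shard only the
        # optimizer states (and, for ZeRO2, the gradients); CAI_Gemini additionally shards
        # the parameters in chunks and places them on GPU/CPU according to --placement.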
        gemini_config = None
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
                                 device=get_current_device(),
                                 placement_policy=args.placement,
                                 pin_memory=True,
                                 hidden_dim=model.config.hidden_size,
                                 search_range_mb=128)
            optim_config = dict(gpu_margin_mem_ratio=0.)
        else:
            raise RuntimeError(f"Unsupported distplan for ZeRO/Gemini configuration: {args.distplan}")

        # build a highly optimized gpu/cpu optimizer
        optimizer = get_optimizer(model, lr=args.lr)

        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError(f"Unsupported distplan: {args.distplan}")

        # wrap your model and optimizer
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

        logger.info(get_mem_info(prefix='After init optim, '))

    else:
        config, model, numel = get_model(args, logger)
        logger.info("no_zero")

    if torch.distributed.get_rank() == 0:
        os.mkdir(os.path.join(args.ckpt_path, launch_time))

    logger.info(f'Model numel: {numel}')

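    # Bind the static quantities (parameter count, micro-batch size, sequence length) now;
    # the per-iteration wall-clock time is supplied at logging time to estimate TFLOPS per GPU.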
    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

    # 144003367 is the total number of samples in the pretraining dataset
    steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size    # equivalent to len(dataloader)
    total_steps = steps_per_epoch * args.epoch

    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    start_epoch = 0
    start_shard = 0
    global_step = 0
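    # Optionally resume mid-run: the optimizer and LR-scheduler states are stored in a
    # separate checkpoint file (args.load_optimizer_lr) alongside the epoch/shard counters.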
    if args.resume_train:
        assert os.path.exists(args.load_optimizer_lr)
        o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu')
        o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
        optimizer.load_state_dict(o_l_state_dict['optimizer'])
        # o_l_state_dict['lr_scheduler']['last_epoch']
        lr_scheduler = get_lr_scheduler(optimizer,
                                        total_steps=total_steps,
                                        last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # If you remove the loop above, you must move the model to the GPU instead,
        # because optimizer.step() expects parameters and optimizer states on the same device.
        lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])

        start_epoch = o_l_state_dict['epoch']
        start_shard = o_l_state_dict['shard'] + 1
        # global_step = o_l_state_dict['global_step'] + 1
        logger.info(
            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
        )

    criterion = LossForPretraining(config.vocab_size)

    # build dataloader
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix='After init model, '))

    best_loss = None
    eval_loss = 0
    train_loss = 0
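    # interval_time accumulates wall-clock time for the whole run (read with reset=False for
    # throughput); shard_time and epoch_time are read at the end of each shard / epoch.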
    timers = get_timers()
    timers('interval_time').start()
    timers('epoch_time').start()
    timers('shard_time').start()

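    # The corpus is pre-sharded on disk under args.data_path_prefix; every epoch iterates
    # over all shards, loading one shard at a time through the NVIDIA BERT dataset provider.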
    for epoch in range(start_epoch, args.epoch):

        for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):

            dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
            # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
            if torch.distributed.get_rank() == 0:
                iterator_data = tqdm(enumerate(dataset_iterator),
                                     total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
                                     colour='cyan',
                                     smoothing=1)
            else:
                iterator_data = enumerate(dataset_iterator)

            model.train()

            for step, batch_data in iterator_data:

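                # Each batch is a tuple of tensors: input_ids, attention_mask, token_type_ids
                # and the masked-LM labels; only the MLM objective is trained here (NSP is unused).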
                # batch_data = pretrain_dataset_provider.get_batch(batch_index)
                input_ids = batch_data[0].cuda()
                attention_mask = batch_data[1].cuda()
                token_type_ids = batch_data[2].cuda()
                mlm_label = batch_data[3].cuda()
                # nsp_label = batch_data[5].cuda()

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

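                # The ZeRO/Gemini optimizer wrapper drives backward itself (handling fp16 loss
                # scaling and gradient partitioning), so optimizer.backward(loss) replaces loss.backward().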
                optimizer.backward(loss)
                train_loss += loss.float().item()
                # if  (step + 1) % args.accumulation_step == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

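                # Rank 0 logs every args.log_interval steps: mean loss over the interval, perplexity,
                # learning rate, and two TFLOPS estimates (analytic per-GPU and throughput-based).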
                if global_step % args.log_interval == 0 and global_step != 0 \
                        and torch.distributed.get_rank() == 0:
                    elapsed_time = timers('interval_time').elapsed(reset=False)
                    elapsed_time_per_iteration = elapsed_time / global_step
                    samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
                        numel, args, config, elapsed_time, global_step, world_size)

                    cur_loss = train_loss / args.log_interval
                    current_lr = lr_scheduler.get_last_lr()[0]
                    log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
                              f'| secs/step: {elapsed_time_per_iteration :.3f} | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}'
                    logger.info(log_str, print_=False)

                    if args.wandb:
                        tensorboard_log = get_tensorboard_writer()
                        tensorboard_log.log_train(
                            {
                                'lr': current_lr,
                                'loss': cur_loss,
                                'ppl': math.exp(cur_loss),
                                'mins_batch': elapsed_time_per_iteration
                            }, global_step)

                    train_loss = 0

            logger.info(f'epoch {epoch} shard {shard} took {timers("shard_time").elapsed() / 60 :.3f} mins')
            logger.info('*' * 100)

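            # After every shard: evaluate, then checkpoint model/optimizer/scheduler together with
            # the epoch, shard and step counters so training can resume from this exact point.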
            eval_loss += evaluate(model, args, logger, global_step, criterion)
            save_ckpt(model, optimizer, lr_scheduler,
                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
                      shard, global_step)

        eval_loss /= len(os.listdir(args.data_path_prefix))
        logger.info(
            f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
            + f' | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
        logger.info('-' * 100)
        if args.wandb and torch.distributed.get_rank() == 0:
            tensorboard_log = get_tensorboard_writer()
            tensorboard_log.log_eval({
                'all_eval_shard_loss': eval_loss,
            }, epoch)
        start_shard = 0
        eval_loss = 0

    pretrain_dataset_provider.release_shard()

    logger.info('Congratulations, training has finished!')


if __name__ == '__main__':
    main()