run_pretraining.py 11.1 KB
Newer Older
mandoxzhang's avatar
mandoxzhang committed
1
import math
2
3
4
5
import os
import time
from functools import partial

mandoxzhang's avatar
mandoxzhang committed
6
7
8
9
10
import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
11
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
mandoxzhang's avatar
mandoxzhang committed
12
13
from tqdm import tqdm
from transformers import AutoTokenizer
14
15
16
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
mandoxzhang's avatar
mandoxzhang committed
17

18
19
import colossalai
from colossalai.context import ParallelMode
20
21
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
22
23
24
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

mandoxzhang's avatar
mandoxzhang committed
25
26
27
28

# Length of the entire dataset in samples (hard-coded); only used to size the LR schedule.
_DATASET_NUM_SAMPLES = 144003367

# Supported ColossalAI distribution plans and the ZeRO stage each one maps to.
_ZERO_STAGE_BY_DISTPLAN = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}


def _build_cai_model_and_optimizer(args, logger, world_size):
    """Build the model and optimizer for a ``CAI_*`` distribution plan.

    Args:
        args: Parsed command-line arguments; reads ``distplan``, ``shardinit``,
            ``tp_degree``, ``placement`` and ``lr``.
        logger: Logger forwarded to ``get_model`` and used for the memory report.
        world_size: Number of processes in the default process group.

    Returns:
        A ``(config, model, numel, optimizer)`` tuple where ``model`` and
        ``optimizer`` have been wrapped by the ZeRO wrappers.

    Raises:
        RuntimeError: If ``shardinit`` is combined with a plan other than
            ``CAI_Gemini``, or if ``distplan`` is not a supported ``CAI_*`` plan.
    """
    if args.shardinit and args.distplan != "CAI_Gemini":
        raise RuntimeError("You can only use shardinit with CAI_Gemini")

    # Validate the plan before the (expensive) model construction instead of after it.
    zero_stage = _ZERO_STAGE_BY_DISTPLAN.get(args.distplan)
    if zero_stage is None:
        raise RuntimeError(
            f"Unsupported distplan {args.distplan!r}; expected one of {sorted(_ZERO_STAGE_BY_DISTPLAN)}"
        )

    # All parameters must use the same process group.
    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

    # Build the model inside the ColossalAI init context (half precision, current device).
    with ColoInitContext(
        device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
    ):
        config, model, numel = get_model(args, logger)

    # Assign running configurations.
    gemini_config = None
    if args.distplan == "CAI_Gemini":
        gemini_config = dict(
            strict_ddp_mode=args.tp_degree == 1,
            device=get_current_device(),
            placement_policy=args.placement,
            pin_memory=True,
            hidden_dim=model.config.hidden_size,
            search_range_m=128,
        )
        optim_config = dict(gpu_margin_mem_ratio=0.0)
    else:
        # CAI_ZeRO1 / CAI_ZeRO2
        optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)

    # Build the optimizer, then wrap model and optimizer for the selected ZeRO stage.
    optimizer = get_optimizer(model, lr=args.lr)
    model = zero_model_wrapper(model, zero_stage, gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

    logger.info(get_mem_info(prefix="After init optim, "))
    return config, model, numel, optimizer


def _resume_training_state(args, optimizer, total_steps, logger):
    """Restore optimizer and LR-scheduler state from ``args.load_optimizer_lr``.

    Args:
        args: Parsed command-line arguments; reads ``load_optimizer_lr``.
        optimizer: Optimizer whose state is overwritten in place.
        total_steps: Total number of scheduler steps for the whole run.
        logger: Logger used to report the resume position.

    Returns:
        A ``(lr_scheduler, start_epoch, start_shard)`` tuple.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if not os.path.exists(args.load_optimizer_lr):
        raise FileNotFoundError(f"Optimizer/lr-scheduler checkpoint not found: {args.load_optimizer_lr}")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(args.load_optimizer_lr, map_location="cpu")
    # NOTE(review): the stored last_epoch is decremented by one before being reused;
    # presumably to compensate for the step taken when a scheduler is constructed - confirm.
    checkpoint["lr_scheduler"]["last_epoch"] = checkpoint["lr_scheduler"]["last_epoch"] - 1
    optimizer.load_state_dict(checkpoint["optimizer"])
    lr_scheduler = get_lr_scheduler(
        optimizer, total_steps=total_steps, last_epoch=checkpoint["lr_scheduler"]["last_epoch"]
    )

    # The state was loaded with map_location="cpu"; move every tensor to the current GPU.
    # (Original author's note: if this loop is removed, the model must be moved to the GPU
    # instead, because of optimizer.step().)
    device = f"cuda:{torch.cuda.current_device()}"
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                param_state[key] = value.cuda(device)
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    start_epoch = checkpoint["epoch"]
    start_shard = checkpoint["shard"] + 1  # continue with the shard after the saved one
    logger.info(
        f"resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}"
    )
    return lr_scheduler, start_epoch, start_shard


def main():
    """Run masked-language-model pretraining, shard by shard.

    Parses the command-line arguments, launches ColossalAI, builds the model,
    optimizer, LR scheduler and loss, optionally resumes from a saved
    optimizer/scheduler state, then iterates over epochs and data shards.
    After every shard the model is evaluated and a checkpoint is saved under
    ``<ckpt_path>/<launch_time>/``.

    Raises:
        RuntimeError: For an unsupported ``CAI_*`` distribution plan or an
            invalid ``shardinit`` combination.
        FileNotFoundError: If ``resume_train`` is set and the checkpoint file
            is missing.
    """
    args = parse_args()
    launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

    if args.vscode_debug:
        # Debug launch: every distributed parameter comes from the command line.
        colossalai.launch(
            config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
        )
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch(config={})  # args.colossal_config
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f"launch_from_torch, world size: {torch.distributed.get_world_size()} | "
            + f"ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}"
        )

    log_args(logger, args)
    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    is_main_process = torch.distributed.get_rank() == 0
    get_current_device()  # return value is unused here; call kept from the original

    # Build model, optimizer and criterion.
    if args.distplan.startswith("CAI"):
        config, model, numel, optimizer = _build_cai_model_and_optimizer(args, logger, world_size)
    else:
        config, model, numel = get_model(args, logger)
        logger.info("no_zero")
        # NOTE(review): no optimizer is created on this path, so the scheduler
        # construction below fails with an unbound ``optimizer`` - TODO decide
        # how the non-ZeRO plan should build its optimizer.

    if is_main_process:
        # makedirs + exist_ok: tolerate a missing parent directory and an existing target.
        os.makedirs(os.path.join(args.ckpt_path, launch_time), exist_ok=True)

    logger.info(f"Model numel: {numel}")

    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

    # NOTE(review): the schedule length is divided by gradient_accumulation_steps and
    # refresh_bucket_size, while the loop below steps the optimizer on every
    # micro-batch - confirm that this is intended.
    steps_per_epoch = (
        _DATASET_NUM_SAMPLES
        // world_size
        // args.train_micro_batch_size_per_gpu
        // args.gradient_accumulation_steps
        // args.refresh_bucket_size
    )
    total_steps = steps_per_epoch * args.epoch

    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    start_epoch = 0
    start_shard = 0
    global_step = 0
    if args.resume_train:
        lr_scheduler, start_epoch, start_shard = _resume_training_state(args, optimizer, total_steps, logger)

    criterion = LossForPretraining(config.vocab_size)

    # Build dataloader.
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix="After init model, "))

    eval_loss = 0
    train_loss = 0
    timers = get_timers()
    timers("interval_time").start()
    timers("epoch_time").start()
    timers("shard_time").start()

    # Loop invariants: one shard per entry in the data directory, one target device.
    num_shards = len(os.listdir(args.data_path_prefix))
    device = f"cuda:{torch.cuda.current_device()}"

    for epoch in range(start_epoch, args.epoch):
        for shard in range(start_shard, num_shards):
            dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
            # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
            if is_main_process:
                # Only rank 0 shows a progress bar.
                iterator_data = tqdm(
                    enumerate(dataset_iterator),
                    total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
                    colour="cyan",
                    smoothing=1,
                )
            else:
                iterator_data = enumerate(dataset_iterator)

            model.train()

            for _step, batch_data in iterator_data:
                input_ids = batch_data[0].cuda(device)
                attention_mask = batch_data[1].cuda(device)
                token_type_ids = batch_data[2].cuda(device)
                mlm_label = batch_data[3].cuda(device)

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

                optimizer.backward(loss)
                train_loss += loss.float().item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step % args.log_interval == 0:
                    if is_main_process:
                        elapsed_time = timers("interval_time").elapsed(reset=False)
                        elapsed_time_per_iteration = elapsed_time / global_step
                        samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
                            numel, args, config, elapsed_time, global_step, world_size
                        )

                        cur_loss = train_loss / args.log_interval
                        current_lr = lr_scheduler.get_last_lr()[0]
                        log_str = (
                            f"| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes "
                            + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}"
                        )
                        logger.info(log_str, print_=False)

                        if args.wandb:
                            tensorboard_log = get_tensorboard_writer()
                            tensorboard_log.log_train(
                                {
                                    "lr": current_lr,
                                    "loss": cur_loss,
                                    "ppl": math.exp(cur_loss),
                                    "mins_batch": elapsed_time_per_iteration,
                                },
                                global_step,
                            )

                    # Reset the running loss on every rank, not just the logging rank,
                    # so the accumulator does not grow without bound elsewhere.
                    train_loss = 0

            logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
            logger.info("*" * 100)

            eval_loss += evaluate(model, args, logger, global_step, criterion)
            save_ckpt(
                model,
                optimizer,
                lr_scheduler,
                os.path.join(args.ckpt_path, launch_time, f"epoch-{epoch}_shard-{shard}_" + launch_time),
                epoch,
                shard,
                global_step,
            )

        # Average the per-shard evaluation losses over the number of shards.
        eval_loss /= num_shards
        logger.info(
            f'epoch {epoch} | shard_length {num_shards} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
            + f" | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}"
        )
        logger.info("-" * 100)
        if args.wandb and is_main_process:
            tensorboard_log = get_tensorboard_writer()
            tensorboard_log.log_eval(
                {
                    "all_eval_shard_loss": eval_loss,
                },
                epoch,
            )
        # Only the first (possibly resumed) epoch starts mid-way; later epochs start at shard 0.
        start_shard = 0
        eval_loss = 0

    pretrain_dataset_provider.release_shard()

    logger.info("Congratulation, training has finished!!!")
mandoxzhang's avatar
mandoxzhang committed
274
275


276
# Script entry point: run pretraining only when executed directly, not on import.
if __name__ == "__main__":
    main()