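"""Finetune OPT for causal language modeling on the Netflix dataset with ColossalAI.

The script builds an OPTForCausalLM model, wraps it with a ColossalAI Booster
(TorchDDP, Gemini, or low-level ZeRO plugin), and trains it on NetflixDataset.

Example launch (a sketch only: the exact flag names are defined in args.parse_demo_args,
which is not shown here, and "facebook/opt-125m" is just an illustrative model id):

    torchrun --nproc_per_node 4 opt_train_demo.py \
        --model_name_or_path facebook/opt-125m \
        --plugin gemini \
        --batch_size 32 \
        --num_epoch 1 \
        --output_path ./opt_checkpoint
"""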
import time

import datasets
import torch
import transformers
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def move_to_cuda(batch, device):
    """Move every tensor in a collated batch dict onto the given device."""
    return {k: v.to(device) for k, v in batch.items()}


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
    """Run one training epoch, showing the per-batch loss on the master rank only."""
    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
        for batch in pbar:
            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())

            outputs = model(use_cache=False, **batch)
            loss = outputs['loss']

            # Backward (goes through the booster so the plugin can handle scaling/sharding)
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()

            # Print batch loss
            pbar.set_postfix({'loss': loss.item()})


def main():

    args = parse_demo_args()

    # Launch ColossalAI
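    # launch_from_torch picks up the distributed environment (RANK, WORLD_SIZE,
    # MASTER_ADDR/PORT) set by `torchrun` or `colossalai run`, so the script is
    # expected to be started through one of those launchers.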
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
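    # (trades extra compute for memory: activations are recomputed during the backward pass)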
    model.gradient_checkpointing_enable()

    # Set plugin
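    # A note on the available plugins (summarizing ColossalAI's documented behavior):
    # - TorchDDPPlugin wraps vanilla PyTorch DistributedDataParallel (optionally with fp16).
    # - GeminiPlugin manages parameters and optimizer states across GPU and CPU memory;
    #   here the optimizer states are fully offloaded to CPU (offload_optim_frac=1.0).
    # - LowLevelZeroPlugin shards optimizer states and gradients in the ZeRO-1/2 style.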
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)

    # Set optimizer (the learning rate is scaled linearly with the number of ranks)
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=num_warmup_steps,
                                                   num_training_steps=total_steps)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
    logger.info("Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)

    # Finish training and save the final checkpoint
    logger.info("Finish finetuning", ranks=[0])
    # Save through the booster so the plugin can gather any sharded parameters first
    booster.save_model(model, args.output_path)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])


if __name__ == "__main__":
    main()