# vit_train_demo.py
import torch
import torch.distributed as dist
import transformers
4
5
from args import parse_demo_args
from data import BeansDataset, beans_collator
6
from tqdm import tqdm
7
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
8
9
10
11
12

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
13
14
15
16
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
17
18
19
20
21
22
23


def move_to_cuda(batch, device):
    """Return a new dict holding every value of ``batch`` moved to ``device``.

    Each value is transferred with its own ``.to(device)`` method; keys are
    kept unchanged and the input dict is not modified.
    """
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(device)
    return moved


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
    """Run one training epoch over ``dataloader``.

    For every batch: zero the gradients, move the batch to the current CUDA
    device, run the forward pass, back-propagate via ``booster.backward``,
    then step the optimizer and the learning-rate scheduler (once per batch).

    Args:
        epoch: Zero-based epoch index; only used in the progress-bar label.
        model: Boosted model whose output supports ``outputs['loss']``.
        optimizer: Boosted optimizer.
        lr_scheduler: Scheduler stepped after every optimizer step.
        dataloader: Iterable of dict batches passed to ``model(**batch)``.
        booster: ColossalAI ``Booster`` that performs the backward pass.
        coordinator: ``DistCoordinator``; the progress bar is shown only on
            the master rank.
    """
    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
        for batch in pbar:

            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())
            outputs = model(**batch)
            loss = outputs['loss']

            # Backward goes through the booster rather than loss.backward()
            # so the active plugin controls gradient computation
            # (e.g. its loss scaling).
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()

            # Show this rank's batch loss on the progress bar.
            pbar.set_postfix({'loss': loss.item()})


@torch.no_grad()
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
    """Evaluate ``model`` on ``eval_dataloader`` and print loss/accuracy on the master rank.

    Each rank accumulates the mean of its batch losses, its sample count and
    its number of correct predictions. The three tensors are summed across
    ranks with ``dist.all_reduce``. The loss sum is then divided by the world
    size so the reported value is a mean rather than ``world_size`` times one.

    Args:
        epoch: Zero-based epoch index; only used in the printed message.
        model: Model whose output's first two items are ``(loss, logits)``.
        eval_dataloader: Iterable of dict batches that include ``"labels"``.
        num_labels: Number of labels. ``> 1`` takes the argmax over dim 1 of
            the logits; ``== 1`` compares the squeezed logits directly with
            the labels.
        coordinator: ``DistCoordinator``; only the master rank prints.

    Raises:
        ValueError: If ``num_labels < 1`` (no prediction rule exists).
    """
    if num_labels < 1:
        raise ValueError(f"num_labels must be >= 1, got {num_labels}")

    model.eval()
    device = get_current_device()
    accum_loss = torch.zeros(1, device=device)
    total_num = torch.zeros(1, device=device)
    accum_correct = torch.zeros(1, device=device)
    num_batches = len(eval_dataloader)

    for batch in eval_dataloader:
        batch = move_to_cuda(batch, torch.cuda.current_device())
        outputs = model(**batch)
        val_loss, logits = outputs[:2]
        # Running mean of this rank's batch losses.
        accum_loss += (val_loss / num_batches)
        if num_labels > 1:
            preds = torch.argmax(logits, dim=1)
        else:
            preds = logits.squeeze()

        labels = batch["labels"]
        total_num += labels.shape[0]
        accum_correct += torch.sum(preds == labels)

    dist.all_reduce(accum_loss)
    dist.all_reduce(total_num)
    dist.all_reduce(accum_correct)
    # all_reduce sums across ranks: accum_loss now holds the sum of every
    # rank's mean loss, so divide by the number of ranks to get a mean.
    accum_loss /= dist.get_world_size()

    total = total_num.item()
    avg_loss = "{:.4f}".format(accum_loss.item())
    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    accuracy = "{:.4f}".format(accum_correct.item() / total if total else float("nan"))
    if coordinator.is_master():
        print(f"Evaluation result for epoch {epoch + 1}: "
              f"average_loss={avg_loss}, "
              f"accuracy={accuracy}.")

def main():
    """Fine-tune a pretrained ViT on the Beans dataset using a ColossalAI booster plugin.

    Parses CLI args, launches ColossalAI from the torch launcher, loads the
    dataset and a ``ViTForImageClassification`` sized to the dataset's labels,
    builds the plugin/optimizer/scheduler, and boosts them. It then trains and
    evaluates once per epoch and finally saves the model to
    ``args.output_path``.

    Raises:
        ValueError: If ``args.plugin`` is not one of ``torch_ddp*``,
            ``gemini`` or ``low_level_zero``.
    """
    args = parse_demo_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers: verbose transformers logging on the master rank only.
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Prepare Dataset
    image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
    train_dataset = BeansDataset(image_processor, split='train')
    eval_dataset = BeansDataset(image_processor, split='validation')

    # Load pretrained ViT model with a label set matching the dataset.
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = train_dataset.num_labels
    config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
    config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
    # ignore_mismatched_sizes allows loading a checkpoint whose classifier
    # head has a different number of labels than ``config.num_labels``.
    model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
                                                      config=config,
                                                      ignore_mismatched_sizes=True)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Without this, an unknown name would surface later as a NameError on `plugin`.
        raise ValueError(f"Plugin {args.plugin} is not supported")
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare dataloader
    train_dataloader = plugin.prepare_dataloader(train_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 drop_last=True,
                                                 collate_fn=beans_collator)
    # Evaluation: deterministic order, and keep the last partial batch so no
    # validation samples are discarded.
    eval_dataloader = plugin.prepare_dataloader(eval_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                drop_last=False,
                                                collate_fn=beans_collator)

    # Set optimizer; the learning rate is scaled linearly with the world size.
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set lr scheduler
    total_steps = len(train_dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=total_steps,
                                           warmup_steps=num_warmup_steps)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
                                                                        optimizer=optimizer,
                                                                        dataloader=train_dataloader,
                                                                        lr_scheduler=lr_scheduler)

    # Finetuning
    logger.info("Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        # A distributed sampler reuses the same shuffle order every epoch
        # unless set_epoch is called before iterating.
        sampler = getattr(train_dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
        evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
    logger.info("Finish finetuning", ranks=[0])

    # Save the finetuned model
    booster.save_model(model, args.output_path)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])


if __name__ == "__main__":
    # Script entry point (expected to be started via a torch distributed launcher).
    main()