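# Benchmarks ViT image classification training under different ColossalAI Booster plugins.
# A hypothetical launch command, assuming 4 GPUs and that the accompanying args module
# exposes CLI flags named after the attributes used in main():
#   torchrun --nproc_per_node 4 vit_benchmark.py \
#       --model_name_or_path google/vit-base-patch16-224 \
#       --plugin gemini --batch_size 8 --max_train_steps 20
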
import time

import torch
import transformers
from args import parse_benchmark_args
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam


def format_num(num: int, bytes=False):
    """Scale a number to a human-readable string, e.g. 1253656 with bytes=True => '1.20 MB'."""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
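# For illustration:
#   format_num(1_253_656, bytes=True) -> '1.20 MB'
#   format_num(1_500_000)             -> '1.50 M'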


def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
    # Generate a synthetic batch directly on the GPU so the benchmark measures
    # model throughput rather than data-loading overhead
    pixel_values = torch.randn(
        batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float
    )
    labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
    return dict(pixel_values=pixel_values, labels=labels)
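# Hypothetical shape check (requires a CUDA device):
#   batch = get_data_batch(8, num_labels=1000)
#   assert batch["pixel_values"].shape == (8, 3, 224, 224) and batch["labels"].shape == (8,)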


def colo_memory_cap(size_in_GB):
    # Cap the fraction of GPU memory this process may allocate, which lets the
    # benchmark emulate a device with less memory than is physically available
    from colossalai.accelerator import get_accelerator
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

    cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def main():
    args = parse_benchmark_args()

    # Launch ColossalAI
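    # (launch_from_torch reads rank and world size from the env vars set by torchrun)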
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Optionally cap the usable GPU memory
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build ViT model from its pretrained config (random weights are fine for a speed benchmark)
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification(config)
    logger.info(f"Finished building model from the config of {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing (recomputes activations in backward to save memory)
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        # Gemini: heterogeneous memory management; offload_optim_frac=1.0 keeps all optimizer states on CPU
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        # ZeRO-style sharding of optimizer states across data-parallel ranks
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
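    # The hybrid_parallel branch below combines 2-way tensor parallelism with 2-way pipeline
    # parallelism, so it assumes a world size divisible by 4; remaining ranks do data parallelism.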
    elif args.plugin == "hybrid_parallel":
        plugin = HybridParallelPlugin(
            tp_size=2,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_all_optimization=True,
            precision="fp16",
            initial_scale=1,
        )
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer; the learning rate is scaled linearly with the number of workers
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set criterion (loss function); the HF model computes the loss itself when labels are passed
    def criterion(outputs, inputs):
        return outputs.loss

    # Set booster: boost() wraps model, optimizer and criterion with the plugin's strategy
    # (the last two return values, dataloader and lr_scheduler, are unused here)
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

    # Start benchmarking
    logger.info("Start testing", ranks=[0])

    # Synchronize so pending CUDA work does not leak into the timed region
    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
        for _ in pbar:
            optimizer.zero_grad()
            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)

            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
                # Run pipeline forward/backward: execute_pipeline consumes an iterator over
                # micro-batches and performs both the forward and backward pass internally
                batch = iter([batch])
                outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
            else:
                outputs = model(**batch)
                loss = criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)

            optimizer.step()

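            # Wait for this step's GPU work to finish so the measured time reflects real execution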
            torch.cuda.synchronize()

    # Compute statistics: throughput is total samples across all ranks per wall-clock second
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput} samples/sec, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0],
    )

    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()