# vit_benchmark.py
import time

import torch
4
import tqdm
5
import transformers
6
from args import parse_benchmark_args
7
8
9
10
11
12
from transformers import ViTConfig, ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
13
14
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
15
16
17
18
19
20
21
22
23
24
25
26
27


def format_num(num: int, bytes=False) -> str:
    """Scale a number to a human-readable string, e.g. 1253656 (bytes=True) => '1.20 MB'.

    Args:
        num: the value to format.
        bytes: when True, scale by powers of 1024 and append a ``B`` suffix;
            otherwise scale by powers of 1000 with no suffix. (The name shadows
            the builtin but is kept because callers pass it by keyword.)

    Returns:
        The scaled value with two decimals followed by its unit prefix. Values
        beyond the largest prefix are expressed in that prefix rather than
        falling through and returning ``None``.
    """
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    units = ["", " K", " M", " G", " T", " P"]
    for unit in units[:-1]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
    # Largest prefix reached: report in it no matter how big the remaining value is.
    return f"{num:.2f}{units[-1]}{suffix}"


def get_data(batch_size, num_labels, num_channels=3, height=224, width=224, device=None):
    """Generate one batch of random synthetic image-classification inputs.

    Args:
        batch_size: number of samples in the batch.
        num_labels: labels are drawn uniformly from ``[0, num_labels)``.
        num_channels: channel dimension of the image tensor.
        height: image height.
        width: image width.
        device: device to allocate the tensors on; ``None`` (the default)
            means the current CUDA device.

    Returns:
        ``(pixel_values, labels)``. ``pixel_values`` is a float32 tensor of shape
        ``(batch_size, num_channels, height, width)`` drawn from a standard normal.
        ``labels`` is an int64 tensor of shape ``(batch_size,)``.
    """
    if device is None:
        # Resolve the default once so both tensors land on the same device.
        device = torch.cuda.current_device()
    pixel_values = torch.randn(batch_size, num_channels, height, width, device=device, dtype=torch.float)
    labels = torch.randint(0, num_labels, (batch_size,), device=device, dtype=torch.int64)
    return pixel_values, labels


def colo_memory_cap(size_in_GB):
    """Restrict this process's share of the current device's memory to ``size_in_GB`` GB.

    The limit is applied via ``colo_set_process_memory_fraction`` only when the
    requested size is strictly smaller than the device capacity reported by
    ``colo_device_memory_capacity``; otherwise this is a no-op.
    """
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    total_bytes = colo_device_memory_capacity(get_current_device())
    requested_bytes = size_in_GB * (1024**3)
    if requested_bytes < total_bytes:
        fraction = requested_bytes / total_bytes
        colo_set_process_memory_fraction(fraction)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def _build_plugin(plugin_name):
    """Map a ``--plugin`` name to a booster plugin plus extra ``Booster`` kwargs.

    Returns:
        ``(plugin, booster_kwargs)``.

    Raises:
        ValueError: if ``plugin_name`` is not a supported plugin name.
    """
    booster_kwargs = {}
    if plugin_name == 'torch_ddp_fp16':
        # fp16 mixed precision is requested on the Booster, not on the plugin.
        booster_kwargs['mixed_precision'] = 'fp16'
    if plugin_name.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif plugin_name == 'gemini':
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif plugin_name == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Previously an unknown name left `plugin` unbound (UnboundLocalError later on).
        raise ValueError(f"Unsupported plugin: {plugin_name}")
    return plugin, booster_kwargs


def main():
    """Benchmark ViT training throughput with a ColossalAI booster plugin.

    Parses CLI args via ``parse_benchmark_args``, then trains for
    ``args.max_train_steps`` steps on random synthetic batches. On rank 0 it
    logs throughput (samples/s summed over all ranks) and peak allocated CUDA
    memory.
    """
    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers: verbose transformers logging only on the master rank
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit on memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build ViT model: only the config is read from the path; the model is built from it
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    plugin, booster_kwargs = _build_plugin(args.plugin)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer; the learning rate is scaled linearly with the number of ranks
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    # Start training.
    logger.info("Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the measurement
    start_time = time.perf_counter()

    for _ in range(args.max_train_steps):
        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()

        # Wait for queued GPU work so the timed region covers completed steps
        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics
    elapsed = time.perf_counter() - start_time
    progress_bar.close()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / elapsed)
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0])
126
127
128
129


if __name__ == '__main__':
    # Run the benchmark only when this file is executed directly, not when imported.
    main()