import time

import torch
import tqdm
import transformers
from args import parse_benchmark_args

from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def format_num(num: int, bytes=False):
    """Scale a number to a human-readable string, e.g. 1253656 bytes => '1.20 MB'."""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
    return f"{num:.2f} E{suffix}"


def get_data(batch_size, seq_len, vocab_size):
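    """Generate a synthetic batch: random token ids plus an all-ones attention mask on the current CUDA device."""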
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def colo_memory_cap(size_in_GB):
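    """Cap this process's usable GPU memory at size_in_GB via ColossalAI's memory-fraction utility (no-op if the cap exceeds device capacity)."""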
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def main():

    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Optionally cap the GPU memory this process may use
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM(config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
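        # Gemini ZeRO: offload_optim_frac=1.0 keeps all optimizer states in (pinned) CPU memory.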
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
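    # HybridAdam can update parameters on both CPU and GPU, which suits Gemini's offloaded optimizer states.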
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
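    # boost() wraps model and optimizer with the chosen plugin; the unused returns are criterion, dataloader and lr_scheduler.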
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

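    # Fixed sequence length and vocabulary size for the synthetic batches.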
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start the benchmark training loop
    logger.info("Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

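    # Wait for pending CUDA work so the timer starts from a clean state.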
    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):

        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
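        # labels=input_ids: the causal-LM head shifts labels internally, so the inputs
        # double as targets; use_cache=False skips KV-cache allocation during training.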
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()

        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute statistics
    end_time = time.time()
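    # Samples processed per second, aggregated across all ranks.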
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0])


if __name__ == "__main__":
    main()