# opt_benchmark.py
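# Benchmarks OPT causal-LM training on synthetic data under different ColossalAI
# booster plugins (torch_ddp, torch_ddp_fp16, gemini, low_level_zero), reporting
# training throughput and peak GPU memory per rank.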
import time

import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def format_num(num: int, bytes=False):
    """Scale a number to a human-readable string, e.g. 1253656 bytes => '1.20 MB'."""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
    return f"{num:.2f} E{suffix}"  # exa-scale fallback so the function always returns


def get_data(batch_size, seq_len, vocab_size):
    """Generate a synthetic batch of random token ids so throughput is independent of data loading."""
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def colo_memory_cap(size_in_GB):
    """Cap this process's usable CUDA memory at size_in_GB (no-op if the device has less)."""
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")

def main():

    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
    
    # Optionally cap the per-process GPU memory
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)
    
    # Build OPT model from its config (randomly initialized; no pretrained weights are loaded).
    # Under the GeminiPlugin, initialize inside ColoInitContext so parameters are
    # sharded across all ranks from the start instead of materialized on one device.
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finished building model from the {args.model_name_or_path} config", ranks=[0])

    # Enable gradient checkpointing: recompute activations in backward to trade compute for memory
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        raise ValueError(f"Unknown plugin {args.plugin!r}")
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer: HybridAdam handles parameters resident on either CPU or GPU,
    # which Gemini's offloading requires
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster: wrap the model and optimizer with the chosen plugin.
    # boost() also accepts a criterion, dataloader and lr_scheduler, unused here.
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
    
    SEQ_LEN = 1024       # fixed benchmark sequence length
    VOCAB_SIZE = 50257   # GPT-2 BPE vocab size; stays within OPT's 50272-entry embedding

    # Run the benchmark loop.
    logger.info("Starting benchmark", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
    
    # Synchronize before starting the timer so pending CUDA work does not skew it
    torch.cuda.synchronize()
    model.train()
    start_time = time.time()
   
    for _ in range(args.max_train_steps):

        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        # Causal LM objective: the input ids double as labels (HF shifts them internally)
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        booster.backward(loss, optimizer)  # plugin-aware backward (loss scaling, ZeRO hooks)
        optimizer.step()

        torch.cuda.synchronize()  # CUDA ops are asynchronous; sync so per-step timing is honest
        progress_bar.update(1)
       
    # Compute statistics: throughput is samples/sec aggregated over all ranks
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
    
    logger.info(f"Testing finished, " 
                f"batch size per gpu: {args.batch_size}, "
                f"plugin: {args.plugin}, "
                f"throughput: {throughput}, "
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])


if __name__ == "__main__":
    main()
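

# A hypothetical launch command (flag names assumed from the args.* attributes used
# above and from parse_benchmark_args in args.py; adjust to whatever that parser
# actually defines):
#
#   torchrun --nproc_per_node 4 opt_benchmark.py \
#       --model_name_or_path facebook/opt-1.3b \
#       --plugin gemini \
#       --batch_size 8 \
#       --max_train_steps 20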