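"""Benchmark the pretraining throughput of LLaMA-family models with ColossalAI
booster plugins (Gemini, Torch FSDP, and hybrid 3D parallelism).

Example launch (hypothetical single node with 8 GPUs; adjust to your setup):
    torchrun --standalone --nproc_per_node 8 benchmark.py -c 7b -p gemini -b 2 -g
"""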
import argparse
import resource
from contextlib import nullcontext

import torch
from attn import replace_with_flash_attention
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

# ==============================
# Constants
# ==============================
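# The sizes below mirror the published LLaMA-2 architectures: 13B uses
# 40 layers x 40 heads at hidden size 5120, and 70B uses grouped-query
# attention (num_key_value_heads=8) shared across its 64 query heads.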

MODEL_CONFIGS = {
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_hidden_layers=40,
        num_attention_heads=40,
        max_position_embeddings=4096,
    ),
    "70b": LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_hidden_layers=80,
        num_attention_heads=64,
        max_position_embeddings=4096,
        num_key_value_heads=8,
    ),
}


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="Warm-up ratio of non-model data. Only for gemini_auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in MB")
    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    args = parser.parse_args()

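    # launch_from_torch initializes the distributed backend from the environment
    # variables set by torchrun (RANK, WORLD_SIZE, MASTER_ADDR/PORT, ...).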
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    def empty_init(module: torch.nn.Module) -> None:
        # FSDP invokes param_init_fn(module) when materializing meta-device
        # modules; a no-op here skips redundant re-initialization.
        pass

    # ==============================
    # Initialize Booster
    # ==============================
    use_empty_init = True
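    # Gemini: ZeRO DP with chunk-based memory management. The *_frac arguments
    # choose how much of the parameters are sharded and how much of the
    # parameters / optimizer states are offloaded to CPU.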
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision="bf16",
            shard_param_frac=args.shard_param_frac,
            offload_optim_frac=args.offload_optim_frac,
            offload_param_frac=args.offload_param_frac,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
        )
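    # "auto" placement lets Gemini move chunks between CPU and GPU based on
    # runtime memory statistics; warmup_non_model_data_ratio reserves headroom
    # for activations during the warm-up steps.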
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            placement_policy="auto",
            precision="bf16",
            warmup_non_model_data_ratio=args.warmup_ratio,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
        )
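    # PyTorch FSDP with fp16 mixed precision for parameters, gradient
    # reductions, and buffers.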
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                param_init_fn=empty_init,
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                )
            )
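    # Same as "fsdp", plus CPUOffload(offload_params=True) to keep the sharded
    # parameters (and their gradients) in host memory.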
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init,
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
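    # "3d": tensor + pipeline + ZeRO parallelism. The interleaved schedule
    # splits each pipeline stage into num_model_chunks virtual stages to
    # shrink the pipeline bubble.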
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style="interleaved",
            zero_stage=args.zero,
            num_model_chunks=2,
            enable_fused_normalization=torch.cuda.is_available(),
            microbatch_size=args.mbs,
            precision="bf16",
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size

    config = MODEL_CONFIGS[args.config]
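    # Every DP rank draws batch_size samples per step, so the dataset must hold
    # batch_size * num_steps * dp_size samples in total.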
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
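    # LazyInitContext records the model structure without allocating real
    # weights; Gemini and HybridParallel materialize and shard the parameters
    # later, inside booster.boost().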
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    with init_ctx:
        model = LlamaForCausalLM(config)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    if args.xformers:
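        # Swap the stock LLaMA attention for a memory-efficient (xFormers)
        # implementation.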
        replace_with_flash_attention(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        model.config.num_hidden_layers,
        model.config.hidden_size,
        model.config.vocab_size,
        args.grad_checkpoint,
        args.ignore_steps,
        dp_world_size=dp_size,
    )

    optimizer = HybridAdam(model.parameters())
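    # Parameters were lazily initialized above; setting bf16 as the default
    # dtype here makes boost() materialize them directly in bf16, after which
    # the default is restored.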
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
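        # With pp > 1 the booster drives the schedule itself: execute_pipeline
        # pulls micro-batches from data_iter and runs forward/backward across
        # the stages, using the criterion to map model outputs to a loss.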
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(
                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
            )
            optimizer.step()
            optimizer.zero_grad()
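            # The real batch was consumed inside execute_pipeline, so pass a
            # dummy tensor of the right shape for throughput accounting.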
            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
    else:
        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
            performance_evaluator.on_step_start(step)
            outputs = model(**batch)
            loss = outputs[0]
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(**batch)

    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()