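"""Benchmark training throughput of LLaMA-family models with ColossalAI plugins.

Launch sketch (an assumption, not part of the original file: assumes 8 local
GPUs and that the sibling modules data_utils.py, model_utils.py and
performance_evaluator.py are importable from the working directory):

    torchrun --nproc_per_node 8 benchmark.py -c 7b -p gemini -b 2 -g
"""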
import argparse
import resource
from contextlib import nullcontext

import torch
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig

# ==============================
# Constants
# ==============================
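# Preset LLaMA configurations keyed by approximate parameter count. Any other
# --config value is treated as a HuggingFace model name/path and resolved via
# AutoConfig in main().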

MODEL_CONFIGS = {
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_hidden_layers=40,
        num_attention_heads=40,
        max_position_embeddings=4096,
    ),
    "70b": LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_hidden_layers=80,
        num_attention_heads=64,
        max_position_embeddings=4096,
        num_key_value_heads=8,
    ),
}


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="Warmup ratio of non-model data; only for gemini_auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in MB")
    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="ZeRO stage when the hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Use a custom pipeline gradient checkpointing config")
    args = parser.parse_args()

    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    def empty_init():
        # No-op: calling empty_init() returns None, so passing empty_init() as
        # FSDP's param_init_fn falls back to default parameter initialization.
        pass

    # Example pipeline gradient checkpoint config, tuned for LLaMA3-70B on 64 H100 GPUs:
    # each stage is assigned a fixed number of layers, and only a subset of them is
    # activation-checkpointed (recomputed in backward) to save memory.
    ckpt_config = (
        PipelineGradientCheckpointConfig(
            num_stages=args.pp,
            num_model_chunks=1,
            num_model_layers=80,
            num_layers_per_stage=[19, 20, 20, 21],
            num_ckpt_layers_per_stage=[19, 19, 19, 13],
        )
        if args.custom_ckpt
        else None
    )

    # ==============================
    # Initialize Booster
    # ==============================
    use_empty_init = True
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision="bf16",
            shard_param_frac=args.shard_param_frac,
            offload_optim_frac=args.offload_optim_frac,
            offload_param_frac=args.offload_param_frac,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            placement_policy="auto",
            precision="bf16",
            warmup_non_model_data_ratio=args.warmup_ratio,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            gradient_checkpoint_config=ckpt_config,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    # The Booster applies the chosen plugin's parallelism/offloading strategy to
    # the model, optimizer and dataloader when boost() is called below.
    booster = Booster(plugin=plugin)

    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = getattr(plugin, "dp_size", coordinator.world_size)

    if args.config in MODEL_CONFIGS:
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
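    # num_samples above equals batch_size * num_steps * dp_size, so with
    # drop_last=True each data-parallel rank sees exactly args.num_steps batches.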
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    init_kwargs = {}
    if config.model_type == "chatglm":
        init_kwargs["empty_init"] = False

    with init_ctx:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
        if config.model_type == "chatglm":
            model.transformer.encoder.gradient_checkpointing = True

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        model.config.num_hidden_layers,
        model.config.hidden_size,
        model.config.vocab_size,
        args.grad_checkpoint,
        args.ignore_steps,
        dp_world_size=dp_size,
    )

    optimizer = HybridAdam(model.parameters())
    # Materialize (possibly lazily initialized) parameters in bf16 during boost,
    # then restore the fp32 default dtype for everything created afterwards.
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
        # With pipeline parallelism, the schedule pulls micro-batches from
        # data_iter itself, so the loop drives booster.execute_pipeline directly.
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(
                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
            )
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
    else:
        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
            performance_evaluator.on_step_start(step)
            outputs = model(**batch)
            loss = outputs[0]
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(**batch)

    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()