# train_gpt_demo.py
1
import os
2
3
4
5
6
7
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
8
9
from commons.model_zoo import model_builder
from commons.utils import get_data, get_tflops
10
from packaging import version
11
from torch.nn.parallel import DistributedDataParallel as DDP
12
13
14
15

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.parallel import ZeroDDP
16
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
17
18
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
19
20
21
22
23
24
25
26

# Version string of the installed ColossalAI package; several code paths in
# this script (imports below, build_gemini, model init in main) branch on it.
CAI_VERSION = colossalai.__version__

if version.parse(CAI_VERSION) > version.parse("0.1.10"):
    # These are added after 0.1.10
    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
27
28


29
30
def str_to_bool(value) -> bool:
    """Convert a command-line string such as 'true'/'false' into a real bool.

    argparse's ``type=bool`` is a well-known trap: ``bool("False")`` is True
    because every non-empty string is truthy, so ``--shardinit False`` would
    silently enable sharding.  This converter accepts the usual spellings and
    raises ValueError otherwise (argparse turns that into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line arguments of the GPT training demo.

    Returns:
        argparse.Namespace: parsed arguments (distplan, tp_degree, placement,
        shardinit, batch_size, model_type, train_step plus colossalai defaults).
    """
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='colossalai',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    # BUG FIX: was ``type=bool``, which treats any non-empty string (including
    # "False") as True.  ``nargs='?'`` + ``const=True`` keeps both forms
    # working: bare ``--shardinit`` and ``--shardinit true/false``.
    parser.add_argument(
        "--shardinit",
        type=str_to_bool,
        nargs='?',
        const=True,
        default=False,
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )
    args = parser.parse_args()
    return args


79
# Parameter Sharding Strategies for Tensor Parallelism
80
81
82
83
84
85
86
87
88
89
90
91
92
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along ``dim`` across the tensor-parallel ranks of ``pg``."""
    shard_spec = ShardSpec([dim], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(shard_spec, compute_spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along dim 0 (row-wise 1D tensor-parallel split)."""
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along its last dim (column-wise 1D tensor-parallel split)."""
    split_param_single_dim_tp1d(-1, param, pg)


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class GPTLMLoss(nn.Module):
    """Causal language-modeling loss: cross-entropy on next-token prediction."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the LM loss; position t's logits are scored against token t+1."""
        # Align predictions with targets: drop the last logit (nothing to
        # predict after it) and the first label (nothing predicts it).
        trimmed_logits = logits[..., :-1, :].contiguous()
        trimmed_labels = labels[..., 1:].contiguous()
        vocab_size = trimmed_logits.size(-1)
        # Flatten (batch, seq) into one axis for CrossEntropyLoss.
        return self.loss_fn(trimmed_logits.view(-1, vocab_size), trimmed_labels.view(-1))


def get_cpu_mem():
    """Resident set size of the current process, in MB."""
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / (1024 ** 2)


def get_gpu_mem():
    """CUDA memory currently allocated by tensors on the default device, in MB."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 ** 2)


def get_mem_info(prefix=''):
    """One-line summary of GPU and CPU memory usage, optionally prefixed."""
    gpu_mb = get_gpu_mem()
    cpu_mb = get_cpu_mem()
    return f'{prefix}GPU memory usage: {gpu_mb:.2f} MB, CPU memory usage: {cpu_mb:.2f} MB'


118
119
120
121
122
123
124
125
def get_model_size(model: nn.Module):
    """Total number of parameter elements in ``model``.

    Counts per-module with ``recurse=False``, so a parameter object attached
    to several modules (e.g. tied weights) is counted once per owning module.
    """
    return sum(p.numel()
               for submodule in model.modules()
               for p in submodule.parameters(recurse=False))


126
127
128
129
130
131
132
133
134
135
136
137
138
139
def model_size_formatter(numel: int) -> str:
    """Format a parameter count for humans, e.g. 1.3B / 350.0M / 12.5K / 999."""
    for threshold, suffix in ((10 ** 9, 'B'), (10 ** 6, 'M'), (10 ** 3, 'K')):
        if numel >= threshold:
            return f'{numel / threshold:.1f}{suffix}'
    return str(numel)


140
141
142
143
144
145
146
147
def set_cpu_maximum_parallelism():
    """Set OMP_NUM_THREADS to the hardware concurrency reported by torch.

    Parses ``torch.__config__.parallel_info()`` for the
    ``hardware_concurrency() : N`` line and exports N so OpenMP-backed ops
    use every available core.
    """
    parallel_info = torch.__config__.parallel_info()
    after_marker = parallel_info.split("hardware_concurrency() : ")[1]
    max_concurrency = after_marker.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


148
149
150
151
152
153
154
155
156
157
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Shard the parameters of a GPT-style model for 1D tensor parallelism.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): process group holding the tensor-parallel ranks
    """
    for module_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            # A parameter may be shared by two modules; shard it only once.
            if hasattr(param, 'visited'):
                continue

            # Start from a replicated layout on the target process group,
            # then apply the per-layer sharding pattern below.
            param: ColoParameter = param
            param.set_dist_spec(ReplicaSpec())
            param.set_process_group(pg)

            # NOTE: branch order matters -- 'mlp.c_proj' must be matched
            # before the generic 'c_proj' (attention output) case.
            if 'mlp.c_fc' in module_name:
                if 'weight' in param_name or 'bias' in param_name:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the output of c_fc sharded for the row-parallel c_proj
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in module_name:
                if 'weight' in param_name:
                    split_param_row_tp1d(param, pg)    # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'wte' in module_name or 'wpe' in module_name:
                split_param_col_tp1d(param, pg)    # column slice (embeddings)
            elif 'c_attn' in module_name or 'c_proj' in module_name:
                split_param_col_tp1d(param, pg)    # column slice (attention)
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
187
188
189


# Gemini + ZeRO DDP
190
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
    """Wrap ``model`` with Gemini/ZeRO-DDP and build a matching fp16 optimizer.

    Args:
        model (torch.nn.Module): the model to wrap; expected to expose
            ``model.config.n_embd`` (GPT-style config) -- TODO confirm for
            other model families.
        pg (ProcessGroup): process group for distributed chunk storage
            (used on the 0.1.9/0.1.10 code path only).
        placement_policy (str): Gemini placement policy ('cpu'/'cuda'/'auto'/'const').
        ddp_flag (bool): forwarded as ``strict_ddp_mode``; callers pass True
            when tensor parallelism is not used.

    Returns:
        tuple: (wrapped model, optimizer)

    Raises:
        NotImplementedError: if the installed ColossalAI version is < 0.1.9.
    """
    fp16_init_scale = 2**5
    gpu_margin_mem_ratio_for_auto = 0

    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
        model = GeminiDDP(model,
                          strict_ddp_mode=ddp_flag,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          hidden_dim=model.config.n_embd,
                          search_range_mb=128)
        # configure the const policy
        if placement_policy == 'const':
            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
        # build a highly optimized cpu optimizer
        optimizer = GeminiAdamOptimizer(model,
                                        lr=1e-3,
                                        initial_scale=fp16_init_scale,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
        # Legacy path: these modules only exist in 0.1.9/0.1.10, so import lazily.
        from colossalai.gemini import ChunkManager, GeminiManager
        from colossalai.nn.optimizer import HybridAdam
        from colossalai.zero import ZeroOptimizer
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        optimizer = ZeroOptimizer(optimizer,
                                  model,
                                  initial_scale=fp16_init_scale,
                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
    else:
        # BUG FIX: previously ``raise NotImplemented(...)`` -- NotImplemented
        # is a constant, not an exception class, so that line itself raised a
        # confusing TypeError instead of the intended error.
        raise NotImplementedError(f"CAI version {CAI_VERSION} is not supported")
    return model, optimizer
229
230


231
def main():
    """Run a short GPT training benchmark under the selected distributed plan.

    Builds the model according to ``--distplan`` (ColossalAI Gemini, torch
    DDP/ZeRO, or ColossalAI zero1/zero2), runs ``--train_step`` iterations on
    random data, and logs per-step timing, memory usage and TFLOPS.
    Expects to be launched via torchrun (``colossalai.launch_from_torch``).
    """
    # version check
    # this example is supposed to work for versions greater than 0.1.9
    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
        # BUG FIX: was ``raise TypeError(f"{args.distplan} is error")`` -- an
        # invalid option value is a ValueError, and the message was garbled.
        raise ValueError(f"Unknown distplan: {args.distplan!r}")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan == "colossalai":
        # all param must use the same process group.
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        # build GPT model
        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
            with ColoInitContext(device=get_current_device(),
                                 dtype=torch.half,
                                 default_dist_spec=default_dist_spec,
                                 default_pg=shard_pg):
                model = model_builder(args.model_type)(checkpoint=True)
        else:
            with ColoInitContext(device=get_current_device()):
                model = model_builder(args.model_type)(checkpoint=True)

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
        # You should notice that v0.1.10 is not compatible with TP degree > 1
        if args.tp_degree > 1:
            tensor_parallelize(model, tp_pg)

        # build a Gemini model and a highly optimized cpu optimizer
        # Gemini + ZeRO DP, Note it must be used after TP
        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    else:
        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()

    if args.distplan.startswith("torch"):
        model = DDP(model)
        if args.distplan.endswith("ddp"):
            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        elif args.distplan.endswith("zero"):
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
    elif args.distplan.startswith("zero"):
        model = model.half()
        partition_flag = (args.distplan == "zero2")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        optimizer = LowLevelZeroOptimizer(
            optimizer,
            reduce_bucket_size=12 * 1024 * 1024,
            overlap_communication=True,
            partition_grad=partition_flag,
            verbose=True,
        )

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of the testing model is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

        # ColossalAI optimizers own the backward pass (loss scaling etc.);
        # plain torch plans call loss.backward() directly.
        if args.distplan in ["colossalai", "zero1", "zero2"]:
            optimizer.backward(loss)
        elif args.distplan in ["torch_ddp", "torch_zero"]:
            loss.backward()
        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        if args.distplan in ["zero1", "zero2"]:
            optimizer.sync_grad()
        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    tflops_list.sort()
    # BUG FIX: tflops_list only contains the (NUM_STEPS - WARMUP_STEPS)
    # post-warmup measurements, so the median index must NOT be offset by
    # WARMUP_STEPS -- the old ``(... >> 1) + WARMUP_STEPS`` picked a value
    # above the median and could even index out of range.
    median_index = (NUM_STEPS - WARMUP_STEPS) >> 1
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()

371
372
373

# Script entry point; main() calls colossalai.launch_from_torch, so this is
# expected to be launched through torchrun / torch.distributed.
if __name__ == '__main__':
    main()