import os
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext

CAI_VERSION = colossalai.__version__
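
# NOTE: this script expects to be started by a distributed launcher, since it calls
# colossalai.launch_from_torch below. A typical invocation would presumably look like
# the following (exact flags are an assumption and depend on your environment):
#   torchrun --standalone --nproc_per_node=<NUM_GPUS> train_gpt_demo.py --distplan CAI_Gemini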

def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [CAI_ZeRO1, CAI_ZeRO2, CAI_Gemini, Pytorch_DDP, Pytorch_ZeRO].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using a CAI_* dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using the CAI_Gemini dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help="Shard the tensors when initializing the model to shrink peak memory size on the assigned device. "
        "Valid only with the CAI_Gemini dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group for training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale of the GPT variant to train.",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="number of training iterations for the benchmark test.",
    )

    args = parser.parse_args()
    return args


# Parameter Sharding Strategies for Tensor Parallelism
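# A brief reading of the helpers below: ShardSpec([dim], [pg.tp_world_size()]) splits a
# parameter along `dim` across the tensor-parallel ranks of `pg`, and ComputePattern.TP1D
# marks it for the 1D tensor-parallel compute pattern.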
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
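        # causal LM objective: the logits at position t predict token t + 1, so drop the
        # last logit and the first label before computing cross-entropy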
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


def model_size_formatter(numel: int) -> str:
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
153
            # NOTE() a param maybe shared by two modules
154
155
            if hasattr(param, 'visited'):
                continue
156
157
158

            # if shard init, then convert param to replica and use the dp-only ProcessGroup
            param: ColoParameter = param
159
            param.set_dist_spec(ReplicaSpec())
160
161
162
            param.set_process_group(pg)

            # shard it w.r.t tp pattern
163
164
            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
digger yu's avatar
digger yu committed
165
                    split_param_col_tp1d(param, pg)    # column slice
166
167
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
168
169
                else:
                    param.set_dist_spec(ReplicaSpec())
170
171
172
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
173
174
                else:
                    param.set_dist_spec(ReplicaSpec())
175
            elif 'wte' in mn or 'wpe' in mn:
digger yu's avatar
digger yu committed
176
                split_param_col_tp1d(param, pg)    # column slice
177
            elif 'c_attn' in mn or 'c_proj' in mn:
digger yu's avatar
digger yu committed
178
                split_param_col_tp1d(param, pg)    # column slice
179
180
181
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
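

# Note on the sharding choices above: in the Hugging Face GPT-2 implementation (which the
# model zoo presumably wraps), the linear layers are `Conv1D` modules whose weight is stored
# as (in_features, out_features). Slicing the last dim ("column slice") therefore splits the
# output features of c_attn and mlp.c_fc, slicing dim 0 ("row slice") splits the input
# features of mlp.c_proj, and for the wte/wpe embeddings the last-dim slice splits the
# hidden dimension.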


def main():
    # version check
    # this example requires ColossalAI >= 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
        raise ValueError(f"{args.distplan} is not a valid distributed plan")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total number of steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False    # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # all parameters must use the same process group.
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build GPT model
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = model_builder(args.model_type)(checkpoint=True)

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
        # Note that ColossalAI v0.1.10 is not compatible with TP degree > 1
        if args.tp_degree > 1:
            tensor_parallelize(model, tp_pg)

        # assign running configurations
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError
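
        # CAI_Gemini maps to ZeRO stage 3 here: Gemini shards parameters in addition to
        # gradients and optimizer states, with its chunk-based memory management on top.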

        plugin = None
        if args.distplan.startswith("CAI_ZeRO"):
            plugin = LowLevelZeroPlugin(stage=zero_stage,
                                        reduce_bucket_size_in_m=12,
                                        overlap_communication=True,
                                        verbose=True)
        elif args.distplan == "CAI_Gemini":
            plugin = GeminiPlugin(device=get_current_device(),
                                  placement_policy=args.placement,
                                  pin_memory=True,
                                  strict_ddp_mode=args.tp_degree == 1,
                                  search_range_m=128,
                                  hidden_dim=model.config.n_embd,
                                  gpu_margin_mem_ratio=0.)
        else:
            raise RuntimeError

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()
        plugin = TorchDDPPlugin()
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)

    else:
        raise RuntimeError
    # wrap your model and optimizer
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
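
    # The Booster rewraps model/optimizer/criterion according to the chosen plugin
    # (LowLevelZero, Gemini, or TorchDDP), so the training loop below only needs to call
    # booster.backward instead of loss.backward.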

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of the testing model is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
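    # (the factor 8 is likely 2 FLOPs per multiply-accumulate x (1 forward + 2 backward + 1 recompute),
    # the recompute pass coming from activation checkpointing since the model is built with checkpoint=True)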
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step():
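        # the torch.cuda.synchronize() calls below make the host-side time() measurements
        # reflect actual GPU execution time for each phase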
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
        booster.backward(loss, optimizer)

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(PROF_FLAG,
                                        WARMUP_STEPS,
                                        NUM_STEPS - WARMUP_STEPS,
                                        save_dir=f"profile/{get_time_stamp()}-demo")

    with demo_profiler as prof:
        for n in range(NUM_STEPS):
            train_step()
            prof.step()

    tflops_list.sort()
    # tflops_list only contains the (NUM_STEPS - WARMUP_STEPS) measured steps, so its median sits at the midpoint
    median_index = (NUM_STEPS - WARMUP_STEPS) >> 1
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == '__main__':
    main()