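# train_gpt_demo.py
# Benchmarks GPT training under several distributed plans (ColossalAI
# ZeRO/Gemini vs. plain PyTorch DDP/ZeRO). A minimal launch sketch, assuming a
# single node with 2 GPUs (the flags are the ones defined in parse_args below):
#   torchrun --nproc_per_node=2 train_gpt_demo.py --distplan CAI_Gemini --batch_size 8
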
import os
from contextlib import nullcontext
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

CAI_VERSION = colossalai.__version__


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="CAI_Gemini",
        help="The distributed plan [CAI_ZeRO1, CAI_ZeRO2, CAI_Gemini, Pytorch_DDP, Pytorch_ZeRO].",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale, e.g. gpt2_medium",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )

    args = parser.parse_args()
    return args


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

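    # Standard causal-LM loss: the logits at position t are scored against the
    # token at position t + 1, so both tensors are shifted by one step below.
    # With logits of shape [B, S, V] and labels [B, S], the loss is averaged
    # over B * (S - 1) flattened positions.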
    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


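# A quick illustration of the thresholds (assumed example values):
#   model_size_formatter(354_823_168)   -> '354.8M'
#   model_size_formatter(1_300_000_000) -> '1.3B'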
def model_size_formatter(numel: int) -> str:
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


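# Sets OMP_NUM_THREADS to the hardware concurrency that
# torch.__config__.parallel_info() reports; main() calls this first so the
# value is in place before any OpenMP thread pools are spun up.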
def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


def main():
    # version check
    # this example is supposed to work for versions greater than 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
        raise ValueError(f"Unknown distplan: {args.distplan}")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False  # whether to run the profiler; False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()
    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
        # build GPT model
        with ctx:
            model = model_builder(args.model_type)(checkpoint=True)

        # assign running configurations
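        # Map each CAI plan onto a ZeRO stage: stage 1 shards optimizer states,
        # stage 2 additionally shards gradients, and Gemini behaves like
        # stage 3 (parameters, gradients, and optimizer states all sharded).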
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError(f"Unsupported CAI distplan: {args.distplan}")

        plugin = None
        if args.distplan.startswith("CAI_ZeRO"):
            plugin = LowLevelZeroPlugin(stage=zero_stage,
                                        reduce_bucket_size_in_m=12,
                                        overlap_communication=True,
                                        verbose=True)
        elif args.distplan == "CAI_Gemini":
            plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
        else:
            raise RuntimeError(f"Unsupported CAI distplan: {args.distplan}")

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        # parse_args() does not define tp_degree here, so default to 1 if absent
        assert getattr(args, "tp_degree", 1) == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()
        plugin = TorchDDPPlugin()
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
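            # ZeroRedundancyOptimizer shards optimizer states across the DDP
            # ranks, the PyTorch counterpart of ZeRO stage 1.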
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)

    else:
        raise RuntimeError(f"Unknown distplan: {args.distplan}")
    # wrap your model and optimizer
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
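    # boost() also accepts a dataloader and an lr_scheduler and returns them in
    # the same order; this demo uses neither, hence the two discarded values.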

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
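    # Worked example with assumed numbers: a ~355M-parameter model with
    # batch_per_DP_group = 8 and seq_len = 1024 does roughly
    # 8 * 355e6 * 1024 * 8 ≈ 2.3e13 FLOPs per step, i.e. ~23 TFLOPS at 1 s/step.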
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step():
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
        booster.backward(loss, optimizer)

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(PROF_FLAG,
                                        WARMUP_STEPS,
                                        NUM_STEPS - WARMUP_STEPS,
                                        save_dir=f"profile/{get_time_stamp()}-demo")

    with demo_profiler as prof:
        for n in range(NUM_STEPS):
            train_step()
            prof.step()

    tflops_list.sort()
    median_index = (NUM_STEPS - WARMUP_STEPS) >> 1  # tflops_list holds only the post-warmup steps
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == '__main__':
    main()