import gzip
import random
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version

from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# constants

NUM_BATCHES = int(10)
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='colossalai',
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help="Shard the tensors during model initialization to shrink peak memory usage on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per data parallel (DP) group for training.",
    )
    parser.add_argument(
        "--dummy_data",
        action='store_true',
        help="use a randomly generated dummy dataset instead of enwik8.",
    )
    args = parser.parse_args()
    return args


# helpers
def cycle(loader):
    while True:
        for data in loader:
            yield data


def decode_token(token):
    return str(chr(max(32, token)))


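# Rough throughput estimate: treats each step as ~8 * params * tokens FLOPs
# (forward + backward, assuming activation recomputation), reported in TFLOPS.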
def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel
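# Sums parameter counts module by module; recurse=False avoids double counting
# since model.modules() already yields every submodule.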


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)
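# "row" shards a parameter along dim 0 and "col" along its last dim; both mark the
# parameter with the 1D tensor-parallel (TP1D) compute pattern.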


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Shard the model parameters for 1D tensor parallelism.

    Args:
        model (torch.nn.Module): the module whose parameters will be sharded
        pg (ProcessGroup): process group describing the tensor-parallel ranks
    """
    # The name patterns below are assumed to match the attention / feed-forward
    # projections of this PaLM implementation; parameters that match none of them
    # stay replicated.
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'net.0' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_q' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_kv' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif 'to_out' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif '1.1' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif '1.2' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True


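# Example launch (assumes a torchrun-style launcher, which launch_from_torch expects):
#   torchrun --nproc_per_node=2 train.py --distplan colossalai --plugin gemini --tp_degree 2 --dummy_data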
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise ValueError(f"Unsupported distplan: {args.distplan!r}. Expected 'colossalai' or 'pytorch'.")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()


def generate_dataset(dummy_data: bool = False):
    if not dummy_data:
        # expects the enwik8 archive at ./data/enwik8.gz; the first 90M bytes are
        # used for training and the following 5M bytes for validation
        with gzip.open("./data/enwik8.gz") as file:
            X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
            trX, vaX = np.split(X, [int(90e6)])
            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
            # print(f"data_val {data_val.shape} {data_val.dtype}  {max(data_val)} {min(data_val)}")
            return data_train, data_val
    else:
        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))


data_train, data_val = generate_dataset(args.dummy_data)

print("generate dataset ready!")


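# Each item is a random (seq_len + 1)-token window from the corpus; the extra token
# allows shifted input/target pairs for next-token prediction.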
class TextSamplerDataset(Dataset):

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))

if args.distplan == "colossalai":
    # instantiate GPT-like decoder model

    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)
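    # --plugin selects the Booster strategy: plain DDP, DDP with fp16 mixed precision,
    # Gemini ZeRO (placement controlled by --placement), or low-level ZeRO.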

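    # Build the model under ColoInitContext so its parameters are ColoParameters that
    # can be tensor-parallelized below; --shardinit pre-shards them along the last dim
    # at init time to reduce peak memory on the init device.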
    default_pg = ProcessGroup(tp_degree=args.tp_degree)
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)

    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    pg = default_pg
    tensor_parallelize(model, pg)

    # optimizer

    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
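    # booster.boost wraps the model and optimizer according to the selected plugin.
    # HybridAdam is ColossalAI's Adam variant that can run on both CPU and GPU,
    # which pairs with CPU placement / offloading.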

else:
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
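    # The 'pytorch' plan trains a much smaller single-GPU PaLM config as a baseline.
    # Note: rebinding `optim` here shadows the `torch.optim as optim` import above.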

# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):

    if args.distplan == "colossalai":
        optimizer.zero_grad()
        start = time()
        loss = model(next(train_loader))
        fwd_end = time()
        fwd_time = fwd_end - start
        # backward through the boosted optimizer rather than calling loss.backward()
        optimizer.backward(loss)
        bwd_end = time()
        bwd_time = bwd_end - fwd_end

        # print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        # optim.step()
        # optim.zero_grad()
        optimizer.step()
        optim_time = time() - bwd_end
        step_time = time() - start

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)
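        # Note: timings are host-side wall clock; CUDA kernels launch asynchronously,
        # so the per-phase breakdown is approximate without explicit synchronization.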

    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            loss.backward()

        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()

# tflops_list only holds post-warmup colossalai steps; it is empty for the pytorch plan.
if tflops_list:
    tflops_list.sort()
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

# TODO: periodic validation and generation inside the training loop, currently disabled.
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# if i % GENERATE_EVERY == 0:
#     model.eval()
#     inp = random.choice(val_dataset)[:-1]
#     prime = decode_tokens(inp)
#     print(f"{prime} \n\n {'*' * 100}")
#
#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)