import gzip
import random
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP

from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# constants

NUM_BATCHES = 10
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='colossalai',
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action="store_true",
        default=False,
        help="Shard the tensors when initializing the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--dummy_data",
        action="store_true",
        default=False,
        help="use a dummy random dataset instead of enwik8.",
    )
    args = parser.parse_args()
    return args

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data


def decode_token(token):
    return str(chr(max(32, token)))

def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
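# NOTE: the factor of 8 above is a rough per-token FLOP estimate: ~2 FLOPs per
# parameter for the forward pass, ~4 for the backward pass, and ~2 more when
# activations are recomputed (gradient checkpointing).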
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel
# Parameter Sharding Strategies for Tensor Parallelism
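# Each helper below attaches a 1D shard spec to a ColoParameter so that it is
# split across the tensor-parallel process group, either along dim 0 ("row")
# or along the last dim ("column").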
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Shard the model parameters for 1D tensor parallelism.

    Parameters under ``net.0``, ``to_q`` and ``1.1`` are split column-wise,
    those under ``to_kv``, ``to_out`` and ``1.2`` row-wise, and everything
    else is replicated.

    Args:
        model (torch.nn.Module): the torch module to be sharded in place
        pg (ProcessGroup): the tensor-parallel process group to shard across
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'net.0' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_q' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_kv' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif 'to_out' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif '1.1' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif '1.2' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True

args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise ValueError(f"Unsupported distplan {args.distplan!r}; expected 'colossalai' or 'pytorch'.")
disable_existing_loggers()
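# launch the distributed environment from torchrun-style env vars, e.g.
# (assuming a single node; adjust --nproc_per_node to your GPU count):
#   torchrun --nproc_per_node=8 train.py --distplan colossalai --plugin gemini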
colossalai.launch_from_torch(config={})
logger = get_dist_logger()

def generate_dataset(dummy_data: bool = False):
    if not dummy_data:
        with gzip.open("./data/enwik8.gz") as file:
            # np.frombuffer returns a read-only view; copy so torch.from_numpy gets a writable array
            X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
            trX, vaX = np.split(X, [int(90e6)])
            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
            # print(f"data_val {data_val.shape} {data_val.dtype}  {max(data_val)} {min(data_val)}")
            return data_train, data_val
    else:
        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))


data_train, data_val = generate_dataset(args.dummy_data)

print("generate dataset ready!")


class TextSamplerDataset(Dataset):

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
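        # sample a random window of seq_len + 1 tokens; the autoregressive
        # wrapper later shifts it into input / target pairs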
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))

if args.distplan == "colossalai":
    # select the Booster plugin: torch DDP (optionally with fp16 mixed precision),
    # Gemini (ZeRO with heterogeneous memory management), or low-level ZeRO
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)
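    # build the tensor-parallel process group; with --shardinit, parameters are
    # sharded along their last dim already at initialization to reduce peak
    # memory on the init device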

    default_pg = ProcessGroup(tp_degree=args.tp_degree)
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
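    # instantiate a GPT-like (PaLM) decoder model under ColoInitContext so its
    # parameters are created as ColoParameters (and pre-sharded if requested)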
    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    pg = default_pg
    tensor_parallelize(model, pg)

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

else:
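    # plain PyTorch baseline: a much smaller PaLM trained on a single device
    # without any ColossalAI features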
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

# training
model.train()
tflops_list = []
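# time the forward / backward / optimizer phases of each step and log per-step
# TFLOPS; the first WARMUP_BATCHES steps are excluded from the median below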
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):

    if args.distplan == "colossalai":
        optimizer.zero_grad()
        start = time()
        loss = model(next(train_loader))
        fwd_end = time()
        fwd_time = fwd_end - start
        # backward through the booster-wrapped optimizer so that any loss
        # scaling / ZeRO hooks installed by the plugin are applied
        optimizer.backward(loss)
        bwd_end = time()
        bwd_time = bwd_end - fwd_end

        # print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optim_time = time() - bwd_end
        step_time = time() - start

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)

    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            loss.backward()

        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()

if tflops_list:
    tflops_list.sort()
    # tflops_list only holds post-warmup steps, so the median sits at its midpoint
    median_index = len(tflops_list) >> 1
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

# TODO: periodic validation on the held-out split
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# TODO: periodic sample generation
# if i % GENERATE_EVERY == 0:
#     model.eval()
#     inp = random.choice(val_dataset)[:-1]
#     prime = decode_tokens(inp)
#     print("%s \n\n %s" % (prime, "*" * 100))
#
#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)