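"""Benchmark training of a PaLM-style autoregressive language model on enwik8
with ColossalAI's Booster API.

Expects ./data/enwik8.gz unless --dummy_data is given. The script calls
colossalai.launch_from_torch, so it is assumed to be started with a
torchrun-style launcher, e.g. (hypothetical invocation):

    torchrun --nproc_per_node=<NUM_GPUS> train.py --distplan colossalai --plugin gemini
"""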
import gzip
from contextlib import nullcontext
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import tqdm
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants

NUM_BATCHES = 10
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='colossalai',
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to offload; only used by the Gemini plugin.",
    )
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
        help="Booster plugin to use.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size per data-parallel group.",
    )
    parser.add_argument(
        "--dummy_data",
        action="store_true",
        help="Use randomly generated dummy data instead of enwik8.",
    )
    args = parser.parse_args()
    return args


# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data


def decode_token(token):
    return chr(max(32, token))


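# Rough achieved-throughput estimate in TFLOPS. The factor of 8 assumes about
# 2 * numel FLOPs per token for the forward pass, 4 * numel for the backward
# pass, and another 2 * numel for re-running the forward under activation
# checkpointing.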
def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


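# Total parameter count; parameters(recurse=False) per module counts each
# parameter exactly once while walking model.modules().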
def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise ValueError(f"Unknown distplan: {args.distplan}")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()


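# enwik8: read the first 95M bytes and split them into 90M training bytes and
# 5M validation bytes; the dummy path mirrors that split with random tokens.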
def generate_dataset(dummy_data: bool = False):
    if not dummy_data:
        with gzip.open("./data/enwik8.gz") as file:
            # np.fromstring was removed from NumPy; frombuffer needs a copy() to be writable
            X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
            trX, vaX = np.split(X, [int(90e6)])
            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
            # print(f"data_val {data_val.shape} {data_val.dtype}  {max(data_val)} {min(data_val)}")
            return data_train, data_val
    else:
        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))


data_train, data_val = generate_dataset(args.dummy_data)

print("generate dataset ready!")


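# Each sample is a random (seq_len + 1)-byte window of the corpus; the
# autoregressive wrapper shifts it into inputs and next-token targets.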
class TextSamplerDataset(Dataset):

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))

if args.distplan == "colossalai":
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)

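    # With Gemini, build parameters lazily so the full fp32 model is never
    # materialized on one device before being sharded.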
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()

    # instantiate GPT-like decoder model
    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    # optimizer
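    # HybridAdam can apply updates to parameters resident on either CPU or GPU,
    # which matters once optimizer states are (partially) offloaded.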
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

else:
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# note: numel is the local parameter count, as the model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    if args.distplan == "colossalai":
        optimizer.zero_grad()
        start = time()
        loss = model(next(train_loader))
        fwd_end = time()
        fwd_time = fwd_end - start
        # the boosted optimizer handles mixed-precision loss scaling, so call
        # optimizer.backward(loss) rather than loss.backward()
        optimizer.backward(loss)
        bwd_end = time()
        bwd_time = bwd_end - fwd_end

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optim_time = time() - bwd_end
        step_time = time() - start

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss: {loss.item():.3f}, Step time: {step_time:.3f}s, "
            f"TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, "
            f"OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
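        # exclude warm-up iterations from the throughput statistics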
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)

    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            # scale so accumulated gradients average over the accumulation steps
            (loss / GRADIENT_ACCUMULATE_EVERY).backward()

        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()

if tflops_list:
    # tflops_list holds only post-warm-up iterations, so its middle element is the median
    tflops_list.sort()
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}", ranks=[0])

# TODO: re-enable validation and generation inside the training loop
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# if i % GENERATE_EVERY == 0:
#     model.eval()
#     inp = random.choice(val_dataset)[:-1]  # requires `import random`
#     prime = decode_tokens(inp)
#     print(f"{prime} \n\n {'*' * 100}")

#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)