# train.py — PaLM training and TFLOPS benchmark script (ColossalAI Booster or plain PyTorch).
import argparse
import gzip
from contextlib import nullcontext
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device


# constants
# Length of the training loop; the first WARMUP_BATCHES steps are left out of the TFLOPS statistics.
NUM_BATCHES = 10
WARMUP_BATCHES = 1
# Backward passes per optimizer step (only used by the plain-PyTorch plan).
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
# Validation / generation cadence and sample length; these appear to be used only by the
# commented-out TODO at the end of the script.
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
# Tokens per training window.
SEQ_LEN = 1024


def parse_args():
    """Parse command-line options on top of ColossalAI's default parser.

    Returns:
        The parsed namespace. Besides whatever ``colossalai.get_default_parser()``
        defines, it carries ``distplan``, ``offload_optim_frac``, ``plugin``,
        ``batch_size`` and ``dummy_data``.
    """

    def _str2bool(value):
        # ``type=bool`` is an argparse trap: bool("False") is True because every
        # non-empty string is truthy, so "--dummy_data False" used to enable dummy data.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="colossalai",
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
    )
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
        help="plugin to use",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    # Accepts a bare "--dummy_data" (=> True) as well as an explicit "--dummy_data true/false",
    # so existing command lines keep working.
    parser.add_argument(
        "--dummy_data",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="use dummy dataset.",
    )
    args = parser.parse_args()
    return args


# helpers


def cycle(loader):
    """Yield items from ``loader`` endlessly, starting a new pass whenever it is exhausted.

    Unlike ``itertools.cycle`` (which replays cached items), ``loader`` is re-iterated on
    every pass, so each pass asks the loader for fresh items.
    """
    while True:
        yield from loader


def decode_token(token):
    """Turn one byte-level token id into a character.

    Ids at or below 32 (ASCII control characters and space) are all rendered as a space.
    """
    return chr(token) if token > 32 else " "


def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate achieved TFLOPS for one training step.

    Counts 8 FLOPs per parameter per token (presumably forward + backward + activation
    recompute — TODO confirm the intended breakdown). The 1e-12 added to ``step_time``
    guards against division by zero.
    """
    total_flops = model_numel * batch_size * seq_len * 8
    return total_flops / 1e12 / (step_time + 1e-12)


def decode_tokens(tokens):
    """Decode an iterable of token ids into a string, one ``decode_token`` call per id."""
    return "".join(decode_token(tok) for tok in tokens)


def get_model_size(model: nn.Module):
    """Count parameter elements by walking each submodule's directly-owned parameters.

    A parameter registered on several distinct modules (e.g. tied weights) is counted once
    per owning module, unlike ``sum(p.numel() for p in model.parameters())``, which
    de-duplicates.
    """
    return sum(
        param.numel()
        for submodule in model.modules()
        for param in submodule.parameters(recurse=False)
    )


args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    # ValueError (not TypeError): the argument's type is fine, its value is unsupported.
    raise ValueError(f"Unsupported distplan {args.distplan!r}; expected 'colossalai' or 'pytorch'.")
disable_existing_loggers()
# Initialise the distributed environment; launch_from_torch presumably reads the
# rank/world-size variables set by torchrun — TODO confirm for the installed ColossalAI version.
colossalai.launch_from_torch(config={})
logger = get_dist_logger()


def generate_dataset(dummy_data: bool = False, path: str = "./data/enwik8.gz"):
    """Load enwik8 bytes and split them into train / validation tensors.

    Args:
        dummy_data: if True, skip the file and return random token tensors with the same
            lengths the real split produces (90e6 train / 5e6 validation).
        path: gzip-compressed enwik8 file; only its first 95e6 bytes are read.

    Returns:
        ``(data_train, data_val)``, two 1-D ``torch.uint8`` tensors. The split is at
        byte 90e6, so a shorter file yields an empty validation tensor.
    """
    if dummy_data:
        # uint8 matches the dtype of the real data and needs 1/8 of the memory of the
        # default int64; samples are cast with .long() when TextSamplerDataset slices them.
        return (
            torch.randint(0, 100, (90000000,), dtype=torch.uint8),
            torch.randint(0, 100, (5000000,), dtype=torch.uint8),
        )
    with gzip.open(path) as file:
        raw = file.read(int(95e6))
    # np.fromstring's binary mode is deprecated in favour of np.frombuffer. frombuffer returns
    # a read-only view of ``raw``, and torch.from_numpy warns on non-writable arrays, so copy.
    X = np.frombuffer(raw, dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    return torch.from_numpy(trX), torch.from_numpy(vaX)


data_train, data_val = generate_dataset(args.dummy_data)

# Log once from rank 0 (as the rest of the script does) instead of printing on every process.
logger.info("generate dataset ready!", ranks=[0])


class TextSamplerDataset(Dataset):
    """Random fixed-length windows over a token tensor.

    ``__getitem__`` ignores ``index`` and draws a fresh random start offset on every call,
    returning ``seq_len + 1`` consecutive tokens as int64 (presumably input plus the shifted
    target for the autoregressive wrapper — TODO confirm).
    """

    def __init__(self, data, seq_len, device=None):
        """
        Args:
            data: token tensor sampled along dim 0 (1-D in this script).
            seq_len: window length; each sample holds ``seq_len + 1`` tokens.
            device: target for samples. ``None`` (default) keeps the original behaviour of
                calling ``.cuda()``; pass e.g. ``"cpu"`` to keep samples on the host.

        Raises:
            ValueError: if ``data`` is too short to hold a single ``seq_len + 1`` window.
        """
        super().__init__()
        if data.size(0) <= seq_len:
            # torch.randint(0, high) needs high >= 1 to pick a start offset; fail here with a
            # clear message instead of an opaque error on the first __getitem__.
            raise ValueError(
                f"data has {data.size(0)} tokens; need more than seq_len={seq_len} to sample a window"
            )
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # randint's upper bound is exclusive, so the window end never exceeds data.size(0).
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        if self.device is None:
            return full_seq.cuda()
        return full_seq.to(self.device)

    def __len__(self):
        return self.data.size(0) // self.seq_len


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# The datasets sample random windows themselves; the loaders are simply re-iterated forever.
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))

if args.distplan == "colossalai":
    # instantiate GPT-like decoder model
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Unreachable while --plugin is restricted by argparse ``choices``; fail loudly rather
        # than hitting a NameError on ``plugin`` below.
        raise ValueError(f"Unsupported plugin: {args.plugin}")
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)

    # Only the Gemini plugin builds the model inside LazyInitContext; the others use a no-op context.
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()

    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
else:
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)
    model.cuda()
    # Named ``optimizer`` (not ``optim``) so the ``torch.optim`` module alias is not shadowed.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): original comment said "model is shared after TP"; with sharding plugins the
# locally visible parameter count may differ from the full model — confirm what this counts.
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    if args.distplan == "colossalai":
        optimizer.zero_grad()
        # CUDA work is asynchronous: synchronize before each host timestamp so the FWD/BWD/OPTIM
        # split (and the TFLOPS derived from it) measures execution, not kernel launch.
        torch.cuda.synchronize()
        start = time()
        loss = model(next(train_loader))
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start

        optimizer.backward(loss)
        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end

        # NOTE(review): plain torch clipping may not see sharded / loss-scaled gradients under
        # Gemini or LowLevelZero; confirm whether the plugin's own grad clipping should be used.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        torch.cuda.synchronize()
        step_end = time()
        optim_time = step_end - bwd_end
        step_time = step_end - start

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        # Warm-up steps are excluded from the statistics.
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)
    else:
        for _ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            loss.backward()

        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

# tflops_list already excludes warm-up steps, so its own middle element is the median (upper
# middle for even lengths). It is empty for the plain-PyTorch plan, which records no TFLOPS.
if tflops_list:
    tflops_list.sort()
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

# TODO: periodic validation and sample generation, not yet wired into the loop above.
# (The generation sketch needs ``import random``; switch back with model.train() afterwards.)
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")
#
# if i % GENERATE_EVERY == 0:
#     model.eval()
#     inp = random.choice(val_dataset)[:-1]
#     prime = decode_tokens(inp)
#     print(f"{prime} \n\n {'*' * 100}")
#
#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)