Unverified commit fe0f7970, authored by ZijianYY and committed by GitHub
Browse files

[examples] adding tflops to PaLM (#2365)

parent 93f62dd1
import gzip
import random
from time import time
from functools import partial
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import tqdm
from packaging import version
from palm_pytorch import PaLM
......@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
# constants

# Number of training iterations the main loop runs.
NUM_BATCHES = 100
# Leading iterations whose throughput is excluded from the TFLOPS statistics.
WARMUP_BATCHES = 1
# Micro-batches accumulated per optimizer step on the plain PyTorch path.
GRADIENT_ACCUMULATE_EVERY = 1
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
......@@ -76,10 +79,18 @@ def cycle(loader):
def decode_token(token):
    """Return the single character for the code point ``token``.

    Code points below 32 are clamped to 32, so they come back as a space
    instead of a control character.
    """
    # chr() already returns a str, so no extra str() conversion is needed.
    return chr(max(32, token))
def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate the throughput of one training step in TFLOPS.

    The estimate charges 8 floating-point operations per parameter per token.
    NOTE(review): 8 presumably covers forward + backward + activation
    recomputation -- confirm against the training setup.

    Args:
        model_numel: Number of model parameters.
        batch_size: Number of sequences processed in the step.
        seq_len: Number of tokens per sequence.
        step_time: Wall-clock duration of the step, in seconds.

    Returns:
        The estimated tera floating-point operations per second.
    """
    # The small epsilon keeps the division defined when step_time is 0.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def decode_tokens(tokens):
    """Decode an iterable of token ids into a string.

    Each token is converted with ``decode_token`` and the resulting
    characters are concatenated in order.
    """
    # str.join consumes any iterable, so the map need not become a list first.
    return "".join(map(decode_token, tokens))
def get_model_size(model: nn.Module):
    """Return the total number of parameter elements in ``model``.

    Every sub-module yielded by ``model.modules()`` contributes only the
    parameters it owns directly (``recurse=False``), so parameters of nested
    modules are counted once, at the module that registers them.

    Args:
        model: The module whose parameters are counted.

    Returns:
        The summed ``numel()`` over all directly-owned parameters.
    """
    return sum(
        p.numel()
        for module in model.modules()
        for p in module.parameters(recurse=False)
    )
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
......@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
......@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
......@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
......@@ -205,25 +216,42 @@ else:
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
if args.distplan == "colossalai":
optimizer.zero_grad()
start = time()
loss = model(next(train_loader))
fwd_end = time()
fwd_time = fwd_end - start
# loss.backward()
optimizer.backward(loss)
bwd_end = time()
bwd_time = bwd_end - fwd_end
print(f"training loss: {loss.item()}")
# print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
optim_time = time() - bwd_end
step_time = time() - start
step_tflops = get_tflops_func(step_time)
logger.info(
f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
ranks=[0],
)
if i >= WARMUP_BATCHES:
tflops_list.append(step_tflops)
else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
......@@ -233,6 +261,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
# Report the median throughput of the measured steps.
# tflops_list only receives samples from iterations with i >= WARMUP_BATCHES,
# so the warm-up steps are already excluded and the median is simply the
# middle element of the sorted list. Adding WARMUP_BATCHES to the index would
# pick the wrong element, or run past the end of the list.
tflops_list.sort()
if tflops_list:
    median_index = len(tflops_list) >> 1
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
else:
    # No samples were collected (e.g. a run where no step appended to the
    # list), so there is no median to report.
    logger.info("Median TFLOPS is unavailable: no post-warmup steps were measured")
# TODO
# if i % VALIDATE_EVERY == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment