Unverified commit 2cdecc9f, authored by Jiarui Fang, committed by GitHub
Browse files

[example] make palm + GeminiDPP work (#2227)

parent 63cc7717
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import einsum, nn, matmul from torch import einsum, matmul, nn
# normalization # normalization
# they use layernorm without bias, something that pytorch does not offer # they use layernorm without bias, something that pytorch does not offer
...@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4): ...@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4):
# attention # attention
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8): def __init__(self, dim, dim_head=64, heads=8):
...@@ -142,8 +140,6 @@ class Attention(nn.Module): ...@@ -142,8 +140,6 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
# split heads # split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper # they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously # they found no performance loss past a certain scale, and more efficient decoding obviously
...@@ -165,7 +161,7 @@ class Attention(nn.Module): ...@@ -165,7 +161,7 @@ class Attention(nn.Module):
# similarity # similarity
#sim = einsum("b h i d, b j d -> b h i j", q, k) #sim = einsum("b h i d, b j d -> b h i j", q, k)
sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2)) sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
sim = sim.reshape(b, h, i, j) sim = sim.reshape(b, h, i, j)
# causal mask # causal mask
...@@ -183,7 +179,7 @@ class Attention(nn.Module): ...@@ -183,7 +179,7 @@ class Attention(nn.Module):
# aggregate values # aggregate values
#out = einsum("b h i j, b j d -> b h i d", attn, v) #out = einsum("b h i j, b j d -> b h i d", attn, v)
out = matmul(attn.reshape(b_, h_*i_, j_), v) out = matmul(attn.reshape(b_, h_ * i_, j_), v)
out = out.reshape(b_, h_, i_, d_) out = out.reshape(b_, h_, i_, d_)
# merge heads # merge heads
......
env OMP_NUM_THREADS=12 torchrun --nproc_per_node 8 --master_port 29501 train.py --config palm_config.py env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --config palm_config.py
\ No newline at end of file
...@@ -5,38 +5,36 @@ import numpy as np ...@@ -5,38 +5,36 @@ import numpy as np
import torch import torch
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from packaging import version
from palm_pytorch import PaLM from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.nn import functional as F from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from packaging import version
import colossalai import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device from colossalai.utils import MultiTimer, get_current_device
from colossalai.nn.parallel import ZeroDDP from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.logging import disable_existing_loggers, get_dist_logger
# constants # constants
NUM_BATCHES = int(1e5) NUM_BATCHES = int(20)
BATCH_SIZE = 4 BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4 GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4 LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100 VALIDATE_EVERY = 100
GENERATE_EVERY = 500 GENERATE_EVERY = 500
GENERATE_LENGTH = 512 GENERATE_LENGTH = 512
SEQ_LEN = 1024 SEQ_LEN = 1024
TPDEGREE = 2 TPDEGREE = 1
USE_SHARD_INIT = False USE_SHARD_INIT = False
placement = 'cpu' placement = 'cpu'
# helpers
# helpers
def cycle(loader): def cycle(loader):
while True: while True:
for data in loader: for data in loader:
...@@ -50,6 +48,7 @@ def decode_token(token): ...@@ -50,6 +48,7 @@ def decode_token(token):
def decode_tokens(tokens): def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens))) return "".join(list(map(decode_token, tokens)))
# Gemini + ZeRO DDP # Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__ cai_version = colossalai.__version__
...@@ -72,7 +71,8 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: ...@@ -72,7 +71,8 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
else: else:
raise NotImplemented(f"CAI version {cai_version} is not supported") raise NotImplemented(f"CAI version {cai_version} is not supported")
return model return model
# instantiate GPT-like decoder model # instantiate GPT-like decoder model
parser = colossalai.get_default_parser() parser = colossalai.get_default_parser()
...@@ -80,24 +80,15 @@ args = parser.parse_args() ...@@ -80,24 +80,15 @@ args = parser.parse_args()
disable_existing_loggers() disable_existing_loggers()
colossalai.launch_from_torch(config=args.config, seed=42) colossalai.launch_from_torch(config=args.config, seed=42)
# instantiate GPT-like decoder model # instantiate GPT-like decoder model
default_pg = ProcessGroup(tp_degree=TPDEGREE) default_pg = ProcessGroup(tp_degree=TPDEGREE)
default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
model = PaLM(num_tokens=256,dim=512,depth=8)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
model.cuda()
# prepare enwik8 data
# model = PaLM(num_tokens=256, dim=512, depth=8) with ctx:
model = PaLM(num_tokens=256, dim=512, depth=8)
# model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
# model.cuda()
with gzip.open("./data/enwik8.gz") as file: with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
...@@ -129,46 +120,42 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE)) ...@@ -129,46 +120,42 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
#tensor_parallelize(model, pg) #tensor_parallelize(model, pg)
pg = default_pg pg = default_pg
# model = GeminiDDP(model,
# device=get_current_device(),
# placement_policy="auto",
# pin_memory=True,
# search_range_mb=32)
model = gemini_zero_dpp(model, pg, placement) model = gemini_zero_dpp(model, pg, placement)
#optimizer #optimizer
optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
#optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training # training
model.train()
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY): optimizer.zero_grad()
loss = model(next(train_loader))
loss.backward() loss = model(next(train_loader))
# loss.backward()
optimizer.backward(loss)
print(f"training loss: {loss.item()}") print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step() # optim.step()
# optim.zero_grad() # optim.zero_grad()
optimizer.step() optimizer.step()
optimizer.zero_grad()
if i % VALIDATE_EVERY == 0: # TODO
model.eval() # if i % VALIDATE_EVERY == 0:
with torch.no_grad(): # model.eval()
loss = model(next(val_loader)) # with torch.no_grad():
print(f"validation loss: {loss.item()}") # loss = model(next(val_loader))
# print(f"validation loss: {loss.item()}")
if i % GENERATE_EVERY == 0:
model.eval() # if i % GENERATE_EVERY == 0:
inp = random.choice(val_dataset)[:-1] # model.eval()
prime = decode_tokens(inp) # inp = random.choice(val_dataset)[:-1]
print(f"%s \n\n %s", (prime, "*" * 100)) # prime = decode_tokens(inp)
# print(f"%s \n\n %s", (prime, "*" * 100))
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0]) # sample = model.generate(inp[None, ...], GENERATE_LENGTH)
print(output_str) # output_str = decode_tokens(sample[0])
# print(output_str)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment