################## 1. Download the checkpoint and build the model
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable the default parameter init for faster startup
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable the default parameter init for faster startup
from models import VQVAE, build_vae_var
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
print(torch.cuda.get_device_name(0))
MODEL_DEPTH = 16 # TODO: change this to pick the model (depth of the VAR transformer)
assert MODEL_DEPTH in {16, 20, 24, 30}
# download checkpoint
model_path="./checkpoint/" # TODO:更改此处,指定模型地址
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = model_path+'vae_ch160v4096z32.pth', model_path+f'var_d{MODEL_DEPTH}.pth'
assert os.path.exists(f'{vae_ckpt}')
assert os.path.exists(f'{var_ckpt}')
# if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
# if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')
# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
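# patch_nums lists the token-map side lengths for the 10 scales (1x1 up to 16x16); with the
# VQVAE's 16-pixel patches, the final 16x16 map corresponds to the 256x256 images these checkpoints produce.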
assert torch.cuda.is_available(), 'this demo requires a CUDA-capable GPU'
device = 'cuda'
if 'vae' not in globals() or 'var' not in globals():
vae, var = build_vae_var(
V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
device=device, patch_nums=patch_nums,
num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
)
# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')
############################# 2. Sampling with classifier-free guidance
# set args
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:"raw"} # TODO: change this to pick the ImageNet class labels to generate (the label-to-class mapping file is in VAR/dataset)
more_smooth = False # True for smoother output
# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode(): # inference: generate the image tensor
with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True): # autocast precision for sampling
recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)
chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0) # arrange the batch into an image grid for display
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.save("./result/inference.png") # TODO:更改此处,指定推理结果存储地址
# chw.show()
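# A hedged, optional sketch (not part of the original script): save each generated image in the
# batch as its own PNG next to the grid, assuming the ./result directory used above exists.
# for i, img_3hw in enumerate(recon_B3HW):
#     img = PImage.fromarray(img_3hw.permute(1, 2, 0).mul(255).cpu().numpy().astype(np.uint8))
#     img.save(f'./result/sample_{i}_cls{class_labels[i]}.png')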
export CUDA_VISIBLE_DEVICES=1,2,3,4
export HSA_FORCE_FINE_GRAIN_PCIE=1 # multi-GPU: force fine-grained PCIe mode, which helps inter-GPU communication efficiency
export USE_MIOPEN_BATCHNORM=1 # performance tweak for multi-GPU parallel training
# torchrun: launches the distributed training job.
# --nproc_per_node=...: number of processes (i.e. GPUs used for training) on each node.
# --nnodes=...: total number of nodes participating in training.
# --node_rank=...: rank of the current node (starting from 0).
# --master_addr=...: IP address of the master node.
# --master_port=...: port on the master node used for communication.
# --depth=16: depth of the VAR model.
# --bs=768: global batch size.
# --ep=200: total number of training epochs.
# --fp16=1: enable FP16 training.
# --alng=1e-3: initialization multiplier for ada_lin.w[gamma channels].
# --wpe=0.1: final lr ratio at the end of training.
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 train.py \
--depth=16 --bs=192 --ep=5 --fp16=1 --alng=1e-3 --wpe=0.1 \
--data_path=/home/VAR/dataset
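# A hedged sketch (not part of the original script): the same launch spread across two nodes,
# using the multi-node flags described above; the address and port values are placeholders,
# and the second node would run the identical command with --node_rank=1.
# torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
#   --master_addr=10.0.0.1 --master_port=29500 train.py \
#   --depth=16 --bs=192 --ep=5 --fp16=1 --alng=1e-3 --wpe=0.1 \
#   --data_path=/home/VAR/dataset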
import gc
import os
import shutil
import sys
import time
import warnings
from functools import partial
import torch
from torch.utils.data import DataLoader
import dist
from utils import arg_util, misc
from utils.data import build_dataset
from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
from utils.misc import auto_resume
def build_everything(args: arg_util.Args):
# resume
auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
# create tensorboard logger
tb_lg: misc.TensorboardLogger
with_tb_lg = dist.is_master()
if with_tb_lg:
os.makedirs(args.tb_log_dir_path, exist_ok=True)
# noinspection PyTypeChecker
tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
tb_lg.flush()
else:
# noinspection PyTypeChecker
tb_lg = misc.DistLogger(None, verbose=False)
dist.barrier()
# log args
print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
print(f'initial args:\n{str(args)}')
# build data
if not args.local_debug:
print(f'[build PT data] ...\n')
num_classes, dataset_train, dataset_val = build_dataset(
args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
)
types = str((type(dataset_train).__name__, type(dataset_val).__name__))
ld_val = DataLoader(
dataset_val, num_workers=0, pin_memory=True,
batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
shuffle=False, drop_last=False,
)
del dataset_val
ld_train = DataLoader(
dataset=dataset_train, num_workers=args.workers, pin_memory=True,
generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,
batch_sampler=DistInfiniteBatchSampler(
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,
shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,
),
)
del dataset_train
[print(line) for line in auto_resume_info]
print(f'[dataloader multi processing] ...', end='', flush=True)
stt = time.time()
iters_train = len(ld_train)
ld_train = iter(ld_train)
# noinspection PyArgumentList
print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)
print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')
else:
num_classes = 1000
ld_val = ld_train = None
iters_train = 10
# build models
from torch.nn.parallel import DistributedDataParallel as DDP
from models import VAR, VQVAE, build_vae_var
from trainer import VARTrainer
from utils.amp_sc import AmpOptimizer
from utils.lr_control import filter_params
vae_local, var_wo_ddp = build_vae_var(
V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
device=dist.get_device(), patch_nums=args.patch_nums,
num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
flash_if_available=args.fuse, fused_if_available=args.fuse,
init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
)
vae_ckpt = './checkpoint/vae_ch160v4096z32.pth' # TODO: path to the pretrained VQVAE checkpoint
if dist.is_local_master():
if not os.path.exists(vae_ckpt):
os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{os.path.basename(vae_ckpt)} -P {os.path.dirname(vae_ckpt)}') # download into the expected local directory
dist.barrier()
vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')
# build optimizer
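# note: filter_params (utils/lr_control.py) places every parameter whose name contains one of the
# nowd_keys substrings, every bias, and every 1-D parameter into a zero-weight-decay group.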
names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
'cls_token', 'start_token', 'task_token', 'cfg_uncond',
'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
'gamma', 'beta',
'ada_gss', 'moe_bias',
'scale_mul',
})
opt_clz = {
'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
}[args.opt.lower().strip()]
opt_kw = dict(lr=args.tlr, weight_decay=0)
print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')
var_optim = AmpOptimizer(
mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
grad_clip=args.tclip, n_gradient_accumulation=args.ac
)
del names, paras, para_groups
# build trainer
trainer = VARTrainer(
device=args.device, patch_nums=args.patch_nums, resos=args.resos,
vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
var_opt=var_optim, label_smooth=args.ls,
)
if trainer_state is not None and len(trainer_state):
trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again
del vae_local, var_wo_ddp, var, var_optim
if args.local_debug:
rng = torch.Generator('cpu')
rng.manual_seed(0)
B = 4
inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
label = torch.ones(B, dtype=torch.long)
me = misc.MetricLogger(delimiter=' ')
trainer.train_step(
it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
)
trainer.load_state_dict(trainer.state_dict())
trainer.train_step(
it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
)
print({k: meter.global_avg for k, meter in me.meters.items()})
args.dump_log(); tb_lg.flush(); tb_lg.close()
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
sys.stdout.close(), sys.stderr.close()
exit(0)
dist.barrier()
return (
tb_lg, trainer, start_ep, start_it,
iters_train, ld_train, ld_val
)
def main_training():
args: arg_util.Args = arg_util.init_dist_and_get_args()
if args.local_debug:
torch.autograd.set_detect_anomaly(True)
(
tb_lg, trainer,
start_ep, start_it,
iters_train, ld_train, ld_val
) = build_everything(args)
# train
start_time = time.time()
best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1
L_mean, L_tail = -1, -1
for ep in range(start_ep, args.ep):
if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
ld_train.sampler.set_epoch(ep)
if ep < 3:
# noinspection PyArgumentList
print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
tb_lg.set_step(ep * iters_train)
stats, (sec, remain_time, finish_time) = train_one_ep(
ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
)
L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
args.cur_ep = f'{ep+1}/{args.ep}'
args.remain_time, args.finish_time = remain_time, finish_time
AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
is_val_and_also_saving = (ep + 1) % 1 == 0 or (ep + 1) == args.ep # TODO: changed to validate and save a checkpoint after every epoch; the original code was is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
if is_val_and_also_saving:
val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
best_updated = best_val_loss_tail > val_loss_tail
best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')
if dist.is_local_master():
local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
print(f'[saving ckpt] ...', end='', flush=True)
torch.save({
'epoch': ep+1,
'iter': 0,
'trainer': trainer.state_dict(),
'args': args.state_dict(),
}, local_out_ckpt)
if best_updated:
shutil.copy(local_out_ckpt, local_out_ckpt_best)
print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
dist.barrier()
print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
args.dump_log(); tb_lg.flush()
total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
print('\n\n')
print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
print('\n\n')
del stats
del iters_train, ld_train
time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
print(f'final args:\n\n{str(args)}')
args.dump_log(); tb_lg.flush(); tb_lg.close()
dist.barrier()
def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
# import heavy packages after Dataloader object creation
from trainer import VARTrainer
from utils.lr_control import lr_wd_annealing
trainer: VARTrainer
step_cnt = 0
me = misc.MetricLogger(delimiter=' ')
me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]
header = f'[Ep]: [{ep:4d}/{args.ep}]'
if is_first_ep:
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
g_it, max_it = ep * iters_train, args.ep * iters_train
for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
g_it = ep * iters_train + it
if it < start_it: continue
if is_first_ep and it == start_it: warnings.resetwarnings()
inp = inp.to(args.device, non_blocking=True)
label = label.to(args.device, non_blocking=True)
args.cur_it = f'{it+1}/{iters_train}'
wp_it = args.wp * iters_train
min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
args.cur_lr, args.cur_wd = max_tlr, max_twd
if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this
if g_it <= wp_it: prog_si = args.pg0
elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1
else:
delta = len(args.patch_nums) - 1 - args.pg0
progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1
prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1
else:
prog_si = -1
stepping = (g_it + 1) % args.ac == 0
step_cnt += int(stepping)
grad_norm, scale_log2 = trainer.train_step(
it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,
inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
)
me.update(tlr=max_tlr)
tb_lg.set_step(step=g_it)
tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
if args.tclip > 0:
tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)
me.synchronize_between_processes()
return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost
class NullDDP(torch.nn.Module):
def __init__(self, module, *args, **kwargs):
super(NullDDP, self).__init__()
self.module = module
self.require_backward_grad_sync = False
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
if __name__ == '__main__':
try: main_training()
finally:
dist.finalize()
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
sys.stdout.close(), sys.stderr.close()
import time
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import dist
from models import VAR, VQVAE, VectorQuantizer2
from utils.amp_sc import AmpOptimizer
from utils.misc import MetricLogger, TensorboardLogger
Ten = torch.Tensor
FTen = torch.Tensor
ITen = torch.LongTensor
BTen = torch.BoolTensor
class VARTrainer(object):
def __init__(
self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
var_opt: AmpOptimizer, label_smooth: float,
):
super(VARTrainer, self).__init__()
self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
self.quantize_local: VectorQuantizer2
self.var_wo_ddp: VAR = var_wo_ddp # after torch.compile
self.var_opt = var_opt
del self.var_wo_ddp.rng
self.var_wo_ddp.rng = torch.Generator(device=device)
self.label_smooth = label_smooth
self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
self.L = sum(pn * pn for pn in patch_nums)
self.last_l = patch_nums[-1] * patch_nums[-1]
self.loss_weight = torch.ones(1, self.L, device=device) / self.L
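# uniform per-token loss weight: each of the L = sum(pn*pn) token positions contributes equally,
# so the weighted sum over positions in train_step averages the cross-entropy over the whole sequence.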
self.patch_nums, self.resos = patch_nums, resos
self.begin_ends = []
cur = 0
for i, pn in enumerate(patch_nums):
self.begin_ends.append((cur, cur + pn * pn))
cur += pn*pn
self.prog_it = 0
self.last_prog_si = -1
self.first_prog = True
@torch.no_grad()
def eval_ep(self, ld_val: DataLoader):
tot = 0
L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
stt = time.time()
training = self.var_wo_ddp.training
self.var_wo_ddp.eval()
for inp_B3HW, label_B in ld_val:
B, V = label_B.shape[0], self.vae_local.vocab_size
inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
label_B = label_B.to(dist.get_device(), non_blocking=True)
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
gt_BL = torch.cat(gt_idx_Bl, dim=1)
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
self.var_wo_ddp.forward
logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1])
acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
tot += B
self.var_wo_ddp.train(training)
stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
dist.allreduce(stats)
tot = round(stats[-1].item())
stats /= tot
L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
return L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt
def train_step(
self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
# if progressive training
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
if self.last_prog_si != prog_si:
if self.last_prog_si != -1: self.first_prog = False
self.last_prog_si = prog_si
self.prog_it = 0
self.prog_it += 1
prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
if self.first_prog: prog_wp = 1 # no prog warmup at first prog stage, as it's already solved in wp
if prog_si == len(self.patch_nums) - 1: prog_si = -1 # max prog, as if no prog
# forward
B, V = label_B.shape[0], self.vae_local.vocab_size
self.var.require_backward_grad_sync = stepping
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
gt_BL = torch.cat(gt_idx_Bl, dim=1)
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
with self.var_opt.amp_ctx:
self.var_wo_ddp.forward
logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
if prog_si >= 0: # in progressive training
bg, ed = self.begin_ends[prog_si]
assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
lw = self.loss_weight[:, :ed].clone()
lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
else: # not in progressive training
lw = self.loss_weight
loss = loss.mul(lw).sum(dim=-1).mean()
# backward
grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)
# log
pred_BL = logits_BLV.data.argmax(dim=-1)
if it == 0 or it in metric_lg.log_iters:
Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
if prog_si >= 0: # in progressive training
Ltail = acc_tail = -1
else: # not in progressive training
Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
grad_norm = grad_norm.item()
metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)
# log to tensorboard
if g_it == 0 or (g_it + 1) % 500 == 0:
prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
dist.allreduce(prob_per_class_is_chosen)
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
if dist.is_master():
if g_it == 0:
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
kw = dict(z_voc_usage=cluster_usage)
for si, (bg, ed) in enumerate(self.begin_ends):
if 0 <= prog_si < si: break
pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
ce = self.val_loss(pred, tar).item()
kw[f'acc_{self.resos[si]}'] = acc
kw[f'L_{self.resos[si]}'] = ce
tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
return grad_norm, scale_log2
def get_config(self):
return {
'patch_nums': self.patch_nums, 'resos': self.resos,
'label_smooth': self.label_smooth,
'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
}
def state_dict(self):
state = {'config': self.get_config()}
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
m = getattr(self, k)
if m is not None:
if hasattr(m, '_orig_mod'):
m = m._orig_mod
state[k] = m.state_dict()
return state
def load_state_dict(self, state, strict=True, skip_vae=False):
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
if skip_vae and 'vae' in k: continue
m = getattr(self, k)
if m is not None:
if hasattr(m, '_orig_mod'):
m = m._orig_mod
ret = m.load_state_dict(state[k], strict=strict)
if ret is not None:
missing, unexpected = ret
print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')
config: dict = state.pop('config', None)
if config is None: config = {}
self.prog_it = config.get('prog_it', 0)
self.last_prog_si = config.get('last_prog_si', -1)
self.first_prog = config.get('first_prog', True)
if config:
for k, v in self.get_config().items():
if config.get(k, None) != v:
err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
if strict: raise AttributeError(err)
else: print(err)
import math
from typing import List, Optional, Tuple, Union
import torch
class NullCtx:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class AmpOptimizer:
def __init__(
self,
mixed_precision: int,
optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
grad_clip: float, n_gradient_accumulation: int = 1,
):
self.enable_amp = mixed_precision > 0
self.using_fp16_rather_bf16 = mixed_precision == 1
if self.enable_amp:
self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler
else:
self.amp_ctx = NullCtx()
self.scaler = None
self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad
self.grad_clip = grad_clip
self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation
def backward_clip_step(
self, stepping: bool, loss: torch.Tensor,
) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
# backward
loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
orig_norm = scaler_sc = None
if self.scaler is not None:
self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
else:
loss.backward(retain_graph=False, create_graph=False)
if stepping:
if self.scaler is not None: self.scaler.unscale_(self.optimizer)
if self.early_clipping:
orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
if self.scaler is not None:
self.scaler.step(self.optimizer)
scaler_sc: float = self.scaler.get_scale()
if scaler_sc > 32768.: # fp16 overflows above 65536, so letting the scale grow past 32768 could be risky
self.scaler.update(new_scale=32768.)
else:
self.scaler.update()
try:
scaler_sc = float(math.log2(scaler_sc))
except Exception as e:
print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
raise e
else:
self.optimizer.step()
if self.late_clipping:
orig_norm = self.optimizer.global_grad_norm
self.optimizer.zero_grad(set_to_none=True)
return orig_norm, scaler_sc
def state_dict(self):
return {
'optimizer': self.optimizer.state_dict()
} if self.scaler is None else {
'scaler': self.scaler.state_dict(),
'optimizer': self.optimizer.state_dict()
}
def load_state_dict(self, state, strict=True):
if self.scaler is not None:
try: self.scaler.load_state_dict(state['scaler'])
except Exception as e: print(f'[fp16 load_state_dict err] {e}')
self.optimizer.load_state_dict(state['optimizer'])
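# A minimal usage sketch (editor's addition, not part of the original file): how a training loop
# is expected to drive AmpOptimizer with gradient accumulation. The toy model and data below are
# illustrative only; mixed_precision=0 keeps the sketch runnable on CPU-only machines.
def _amp_optimizer_usage_sketch(n_accu: int = 2):
    import torch
    import torch.nn.functional as F
    model = torch.nn.Linear(8, 4)
    opt = AmpOptimizer(
        mixed_precision=0,  # 0: no AMP; 1 would use fp16 + GradScaler, 2 would use bf16
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3),
        names=[n for n, _ in model.named_parameters()],
        paras=[p for _, p in model.named_parameters()],
        grad_clip=2.0, n_gradient_accumulation=n_accu,
    )
    for it in range(2 * n_accu):
        x, y = torch.randn(16, 8), torch.randn(16, 4)
        with opt.amp_ctx:  # autocast context (a no-op NullCtx when AMP is disabled)
            loss = F.mse_loss(model(x), y)
        stepping = (it + 1) % n_accu == 0  # clip and step only on every n_accu-th micro-batch
        grad_norm, scale_log2 = opt.backward_clip_step(loss=loss, stepping=stepping)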
import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
import torch
try:
from tap import Tap
except ImportError as e:
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
time.sleep(5)
raise e
import dist
class Args(Tap):
data_path: str = '/path/to/imagenet'
exp_name: str = 'text'
# VAE
vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
# VAR
tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
depth: int = 16 # VAR depth
# VAR initialization
ini: float = -1 # -1: automated model parameter initialization
hd: float = 0.02 # head.w *= hd
aln: float = 0.5 # the multiplier of ada_lin.w's initialization
alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
# VAR optimization
fp16: int = 0 # 1: using fp16, 2: bf16
tblr: float = 1e-4 # base lr
tlr: float = None # lr = base lr * (bs / 256)
twd: float = 0.05 # initial wd
twde: float = 0 # final wd, =twde or twd
tclip: float = 2. # <=0 for not using grad clip
ls: float = 0.0 # label smooth
bs: int = 768 # global batch size
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
ac: int = 1 # gradient accumulation
ep: int = 250
wp: float = 0
wp0: float = 0.005 # initial lr ratio at the beginning of lr warmup
wpe: float = 0.01 # final lr ratio at the end of training
sche: str = 'lin0' # lr schedule
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
afuse: bool = True # fused adamw
# other hps
saln: bool = False # whether to use shared adaln
anorm: bool = True # whether to use L2 normalized attention
fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
# data
pn: str = '1_2_3_4_5_6_8_10_13_16'
patch_size: int = 16
patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
hflip: bool = False # augmentation: horizontal flip
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
# progressive training
pg: float = 0.0 # >0 for use progressive training during [0%, this] of training
pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
pgwp: float = 0 # num of warmup epochs at each progressive stage
# would be automatically set in runtime
cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
acc_mean: float = None # [automatically set; don't specify this]
acc_tail: float = None # [automatically set; don't specify this]
L_mean: float = None # [automatically set; don't specify this]
L_tail: float = None # [automatically set; don't specify this]
vacc_mean: float = None # [automatically set; don't specify this]
vacc_tail: float = None # [automatically set; don't specify this]
vL_mean: float = None # [automatically set; don't specify this]
vL_tail: float = None # [automatically set; don't specify this]
grad_norm: float = None # [automatically set; don't specify this]
cur_lr: float = None # [automatically set; don't specify this]
cur_wd: float = None # [automatically set; don't specify this]
cur_it: str = '' # [automatically set; don't specify this]
cur_ep: str = '' # [automatically set; don't specify this]
remain_time: str = '' # [automatically set; don't specify this]
finish_time: str = '' # [automatically set; don't specify this]
# environment
local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
log_txt_path: str = '...' # [automatically set; don't specify this]
last_ckpt_path: str = '...' # [automatically set; don't specify this]
tf32: bool = True # whether to use TensorFloat32
device: str = 'cpu' # [automatically set; don't specify this]
seed: int = None # seed
def seed_everything(self, benchmark: bool):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = benchmark
if self.seed is None:
torch.backends.cudnn.deterministic = False
else:
torch.backends.cudnn.deterministic = True
seed = self.seed * dist.get_world_size() + dist.get_rank()
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
if self.seed is None:
return None
g = torch.Generator()
g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
return g
local_debug: bool = 'KEVIN_LOCAL' in os.environ
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
def compile_model(self, m, fast):
if fast == 0 or self.local_debug:
return m
return torch.compile(m, mode={
1: 'reduce-overhead',
2: 'max-autotune',
3: 'default',
}[fast]) if hasattr(torch, 'compile') else m
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
d = (OrderedDict if key_ordered else dict)()
# self.as_dict() would contain methods, but we only need variables
for k in self.class_variables.keys():
if k not in {'device'}: # these are not serializable
d[k] = getattr(self, k)
return d
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
if isinstance(d, str): # for compatibility with old version
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
for k in d.keys():
try:
setattr(self, k, d[k])
except Exception as e:
print(f'k={k}, v={d[k]}')
raise e
@staticmethod
def set_tf32(tf32: bool):
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
def dump_log(self):
if not dist.is_local_master():
return
if '1/' in self.cur_ep: # first time to dump log
with open(self.log_txt_path, 'w') as fp:
json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
fp.write('\n')
log_dict = {}
for k, v in {
'it': self.cur_it, 'ep': self.cur_ep,
'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
'remain_time': self.remain_time, 'finish_time': self.finish_time,
}.items():
if hasattr(v, 'item'): v = v.item()
log_dict[k] = v
with open(self.log_txt_path, 'a') as fp:
fp.write(f'{log_dict}\n')
def __str__(self):
s = []
for k in self.class_variables.keys():
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
s.append(f' {k:20s}: {getattr(self, k)}')
s = '\n'.join(s)
return f'{{\n{s}\n}}\n'
def init_dist_and_get_args():
for i in range(len(sys.argv)):
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
del sys.argv[i]
break
args = Args(explicit_bool=True).parse_args(known_only=True)
if args.local_debug:
args.pn = '1_2_3'
args.seed = 1
args.aln = 1e-2
args.alng = 1e-5
args.saln = False
args.afuse = False
args.pg = 0.8
args.pg0 = 1
else:
if args.data_path == '/path/to/imagenet':
raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
# warn args.extra_args
if len(args.extra_args) > 0:
print(f'======================================================================================')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
print(f'======================================================================================\n\n')
# init torch distributed
from utils import misc
os.makedirs(args.local_out_dir_path, exist_ok=True)
misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
# set env
args.set_tf32(args.tf32)
args.seed_everything(benchmark=args.pg == 0)
# update args: data loading
args.device = dist.get_device()
if args.pn == '256':
args.pn = '1_2_3_4_5_6_8_10_13_16'
elif args.pn == '512':
args.pn = '1_2_3_4_6_9_13_18_24_32'
elif args.pn == '1024':
args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
args.data_load_reso = max(args.resos)
# update args: bs and lr
bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
args.batch_size = bs_per_gpu
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
args.workers = min(max(0, args.workers), args.batch_size)
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
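# e.g. with ac=1, tblr=1e-4 and a global batch size of 768, the peak lr is 1e-4 * 768 / 256 = 3e-4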
args.twde = args.twde or args.twd
if args.wp == 0:
args.wp = args.ep * 1/50
# update args: progressive training
if args.pgwp == 0:
args.pgwp = args.ep * 1/300
if args.pg > 0:
args.sche = f'lin{args.pg:g}'
# update args: paths
args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
_reg_valid_name = re.compile(r'[^\w\-+,.]')
tb_name = _reg_valid_name.sub(
'_',
f'tb-VARd{args.depth}'
f'__pn{args.pn}'
f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
)
args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
return args
import os.path as osp
import PIL.Image as PImage
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode, transforms
def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
return x.add(x).add_(-1)
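# e.g. normalize_01_into_pm1(torch.tensor([0.0, 0.5, 1.0])) -> tensor([-1., 0., 1.]); the trailing
# add_ is in-place on the intermediate (x + x) result, so the input tensor itself is untouched.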
def build_dataset(
data_path: str, final_reso: int,
hflip=False, mid_reso=1.125,
):
# build augmentations
mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso
train_aug, val_aug = [
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
transforms.RandomCrop((final_reso, final_reso)),
transforms.ToTensor(), normalize_01_into_pm1,
], [
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
transforms.CenterCrop((final_reso, final_reso)),
transforms.ToTensor(), normalize_01_into_pm1,
]
if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
# build dataset
train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
num_classes = 1000
print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
print_aug(train_aug, '[train]')
print_aug(val_aug, '[val]')
return num_classes, train_set, val_set
def pil_loader(path):
with open(path, 'rb') as f:
img: PImage.Image = PImage.open(f).convert('RGB')
return img
def print_aug(transform, label):
print(f'Transform {label} = ')
if hasattr(transform, 'transforms'):
for t in transform.transforms:
print(t)
else:
print(transform)
print('---------------------------\n')
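# Usage sketch (editor's addition): train.py calls this as
#   num_classes, train_set, val_set = build_dataset(args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso)
# where data_path must point at an ImageNet-style class-subfolder layout (train/<class>/*.JPEG, val/<class>/*.JPEG).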
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class EvalDistributedSampler(Sampler):
def __init__(self, dataset, num_replicas, rank):
seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
beg, end = seps[:-1], seps[1:]
beg, end = beg[rank], end[rank]
self.indices = tuple(range(beg, end))
def __iter__(self):
return iter(self.indices)
def __len__(self) -> int:
return len(self.indices)
class InfiniteBatchSampler(Sampler):
def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
self.dataset_len = dataset_len
self.batch_size = batch_size
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
self.max_p = self.iters_per_ep * batch_size
self.fill_last = fill_last
self.shuffle = shuffle
self.epoch = start_ep
self.same_seed_for_all_ranks = seed_for_all_rank
self.indices = self.gener_indices()
self.start_ep, self.start_it = start_ep, start_it
def gener_indices(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
indices = torch.randperm(self.dataset_len, generator=g).numpy()
else:
indices = torch.arange(self.dataset_len).numpy()
tails = self.batch_size - (self.dataset_len % self.batch_size)
if tails != self.batch_size and self.fill_last:
tails = indices[:tails]
np.random.shuffle(indices)
indices = np.concatenate((indices, tails))
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
# noinspection PyTypeChecker
return tuple(indices.tolist())
def __iter__(self):
self.epoch = self.start_ep
while True:
self.epoch += 1
p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
while p < self.max_p:
q = p + self.batch_size
yield self.indices[p:q]
p = q
if self.shuffle:
self.indices = self.gener_indices()
def __len__(self):
return self.iters_per_ep
class DistInfiniteBatchSampler(InfiniteBatchSampler):
def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
assert glb_batch_size % world_size == 0
self.world_size, self.rank = world_size, rank
self.dataset_len = dataset_len
self.glb_batch_size = glb_batch_size
self.batch_size = glb_batch_size // world_size
self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
self.fill_last = fill_last
self.shuffle = shuffle
self.repeated_aug = repeated_aug
self.epoch = start_ep
self.same_seed_for_all_ranks = same_seed_for_all_ranks
self.indices = self.gener_indices()
self.start_ep, self.start_it = start_ep, start_it
def gener_indices(self):
global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
# print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
global_indices = torch.randperm(self.dataset_len, generator=g)
if self.repeated_aug > 1:
global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
else:
global_indices = torch.arange(self.dataset_len)
filling = global_max_p - global_indices.shape[0]
if filling > 0 and self.fill_last:
global_indices = torch.cat((global_indices, global_indices[:filling]))
# global_indices = tuple(global_indices.numpy().tolist())
seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
self.max_p = len(local_indices)
return local_indices
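# A small self-check sketch (editor's addition, not part of the original file): shows how
# DistInfiniteBatchSampler splits one global index sequence into per-rank batches. The toy
# numbers (dataset_len=10, 2 ranks, glb_batch_size=4, no shuffling) are illustrative only.
def _dist_infinite_batch_sampler_sketch():
    for rank in range(2):
        sampler = DistInfiniteBatchSampler(
            world_size=2, rank=rank, dataset_len=10, glb_batch_size=4,
            shuffle=False, fill_last=True,
        )
        it = iter(sampler)
        one_epoch = [next(it) for _ in range(sampler.iters_per_ep)]
        print(f'rank {rank}: {one_epoch}')  # rank 0: [[0, 1], [2, 3], [4, 5]], rank 1: [[6, 7], [8, 9], [0, 1]]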
import math
from pprint import pformat
from typing import Tuple, List, Dict, Union
import torch.nn
import dist
def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
"""Decay the learning rate with half-cycle cosine after warmup"""
wp_it = round(wp_it)
if cur_it < wp_it:
cur_lr = wp0 + (1-wp0) * cur_it / wp_it
else:
pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
rest = 1 - pasd # [1, 0]
if sche_type == 'cos':
cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
elif sche_type == 'lin':
T = 0.15; max_rest = 1-T
if pasd < T: cur_lr = 1
else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
elif sche_type == 'lin0':
T = 0.05; max_rest = 1-T
if pasd < T: cur_lr = 1
else: cur_lr = wpe + (1-wpe) * rest / max_rest
elif sche_type == 'lin00':
cur_lr = wpe + (1-wpe) * rest
elif sche_type.startswith('lin'):
T = float(sche_type[3:]); max_rest = 1-T
wpe_mid = wpe + (1-wpe) * max_rest
wpe_mid = (1 + wpe_mid) / 2
if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
elif sche_type == 'exp':
T = 0.15; max_rest = 1-T
if pasd < T: cur_lr = 1
else:
expo = (pasd-T) / max_rest * math.log(wpe)
cur_lr = math.exp(expo)
else:
raise NotImplementedError(f'unknown sche_type {sche_type}')
cur_lr *= peak_lr
pasd = cur_it / (max_it-1)
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
inf = 1e6
min_lr, max_lr = inf, -1
min_wd, max_wd = inf, -1
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
max_lr = max(max_lr, param_group['lr'])
min_lr = min(min_lr, param_group['lr'])
param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
max_wd = max(max_wd, param_group['weight_decay'])
if param_group['weight_decay'] > 0:
min_wd = min(min_wd, param_group['weight_decay'])
if min_lr == inf: min_lr = -1
if min_wd == inf: min_wd = -1
return min_lr, max_lr, min_wd, max_wd
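# A quick illustration (editor's addition, not part of the original file): evaluate the repo's
# default 'lin0' schedule at a few iterations with a throwaway optimizer; the numbers are illustrative.
def _lr_wd_annealing_sketch():
    import torch
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    max_it, wp_it = 1000, 50
    for cur_it in (0, 25, 50, 500, 999):
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
            'lin0', opt, peak_lr=3e-4, wd=0.05, wd_end=0.05,
            cur_it=cur_it, wp_it=wp_it, max_it=max_it, wp0=0.005, wpe=0.01,
        )
        print(f'it={cur_it:4d}  lr={max_lr:.3e}  wd={max_wd:.3f}')  # lr warms up, holds briefly, then decays linearly to wpe * peak_lr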
def filter_params(model, nowd_keys=()) -> Tuple[
List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
para_groups, para_groups_dbg = {}, {}
names, paras = [], []
names_no_grad = []
count, numel = 0, 0
for name, para in model.named_parameters():
name = name.replace('_fsdp_wrapped_module.', '')
if not para.requires_grad:
names_no_grad.append(name)
continue # frozen weights
count += 1
numel += para.numel()
names.append(name)
paras.append(para)
if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
cur_wd_sc, group_name = 0., 'ND'
else:
cur_wd_sc, group_name = 1., 'D'
cur_lr_sc = 1.
if group_name not in para_groups:
para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
para_groups[group_name]['params'].append(para)
para_groups_dbg[group_name]['params'].append(name)
for g in para_groups_dbg.values():
g['params'] = pformat(', '.join(g['params']), width=200)
print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
for rk in range(dist.get_world_size()):
dist.barrier()
if dist.get_rank() == rk:
print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
print('')
assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
return names, paras, list(para_groups.values())
import datetime
import functools
import glob
import os
import subprocess
import sys
import time
from collections import defaultdict, deque
from typing import Iterator, List, Tuple
import numpy as np
import pytz
import torch
import torch.distributed as tdist
import dist
from utils import arg_util
os_system = functools.partial(subprocess.call, shell=True)
def echo(info):
os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
def os_system_get_stdout(cmd):
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
cnt = 0
while True:
try:
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
except subprocess.TimeoutExpired:
cnt += 1
print(f'[fetch free_port file] timeout cnt={cnt}')
else:
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
def time_str(fmt='[%m-%d %H:%M:%S]'):
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
try:
dist.initialize(fork=False, timeout=timeout)
dist.barrier()
except RuntimeError:
print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
time.sleep(10)
if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
_change_builtin_print(dist.is_local_master())
if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)
def _change_builtin_print(is_master):
import builtins as __builtin__
builtin_print = __builtin__.print
if type(builtin_print) != type(open):
return
def prt(*args, **kwargs):
force = kwargs.pop('force', False)
clean = kwargs.pop('clean', False)
deeper = kwargs.pop('deeper', False)
if is_master or force:
if not clean:
f_back = sys._getframe().f_back
if deeper and f_back.f_back is not None:
f_back = f_back.f_back
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
else:
builtin_print(*args, **kwargs)
__builtin__.print = prt
class SyncPrint(object):
def __init__(self, local_output_dir, sync_stdout=True):
self.sync_stdout = sync_stdout
self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
existing = os.path.exists(fname)
self.file_stream = open(fname, 'a')
if existing:
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
self.file_stream.flush()
self.enabled = True
def write(self, message):
self.terminal_stream.write(message)
self.file_stream.write(message)
def flush(self):
self.terminal_stream.flush()
self.file_stream.flush()
def close(self):
if not self.enabled:
return
self.enabled = False
self.file_stream.flush()
self.file_stream.close()
if self.sync_stdout:
sys.stdout = self.terminal_stream
sys.stdout.flush()
else:
sys.stderr = self.terminal_stream
sys.stderr.flush()
def __del__(self):
self.close()
class DistLogger(object):
def __init__(self, lg, verbose):
self._lg, self._verbose = lg, verbose
@staticmethod
def do_nothing(*args, **kwargs):
pass
def __getattr__(self, attr: str):
return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing
class TensorboardLogger(object):
def __init__(self, log_dir, filename_suffix):
try: import tensorflow_io as tfio
except: pass
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
self.step = 0
def set_step(self, step=None):
if step is not None:
self.step = step
else:
self.step += 1
def update(self, head='scalar', step=None, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
# assert isinstance(v, (float, int)), type(v)
if step is None: # iter wise
it = self.step
if it == 0 or (it + 1) % 500 == 0:
if hasattr(v, 'item'): v = v.item()
self.writer.add_scalar(f'{head}/{k}', v, it)
else: # epoch wise
if hasattr(v, 'item'): v = v.item()
self.writer.add_scalar(f'{head}/{k}', v, step)
def log_tensor_as_distri(self, tag, tensor1d, step=None):
if step is None: # iter wise
step = self.step
loggable = step == 0 or (step + 1) % 500 == 0
else: # epoch wise
loggable = True
if loggable:
try:
self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
except Exception as e:
print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
def log_image(self, tag, img_chw, step=None):
if step is None: # iter wise
step = self.step
loggable = step == 0 or (step + 1) % 500 == 0
else: # epoch wise
loggable = True
if loggable:
self.writer.add_image(tag, img_chw, step, dataformats='CHW')
def flush(self):
self.writer.flush()
def close(self):
self.writer.close()
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=30, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
tdist.barrier()
tdist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
return np.median(self.deque) if len(self.deque) else 0
@property
def avg(self):
return sum(self.deque) / (len(self.deque) or 1)
@property
def global_avg(self):
return self.total / (self.count or 1)
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1] if len(self.deque) else 0
def time_preds(self, counts) -> Tuple[float, str, str]:
remain_secs = counts * self.median
return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter=' '):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
self.iter_end_t = time.time()
self.log_iters = []
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if hasattr(v, 'item'): v = v.item()
# assert isinstance(v, (float, int)), type(v)
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
if len(meter.deque):
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
self.log_iters.add(start_it)
if not header:
header = ''
start_time = time.time()
self.iter_end_t = time.time()
self.iter_time = SmoothedValue(fmt='{avg:.4f}')
self.data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(max_iters))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
log_msg = self.delimiter.join(log_msg)
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
for i in range(start_it, max_iters):
obj = next(itrt)
self.data_time.update(time.time() - self.iter_end_t)
yield i, obj
self.iter_time.update(time.time() - self.iter_end_t)
if i in self.log_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)), flush=True)
self.iter_end_t = time.time()
else:
if isinstance(itrt, int): itrt = range(itrt)
for i, obj in enumerate(itrt):
self.data_time.update(time.time() - self.iter_end_t)
yield i, obj
self.iter_time.update(time.time() - self.iter_end_t)
if i in self.log_iters:
eta_seconds = self.iter_time.global_avg * (max_iters - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(log_msg.format(
i, max_iters, eta=eta_string,
meters=str(self),
time=str(self.iter_time), data=str(self.data_time)), flush=True)
self.iter_end_t = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.3f} s / it)'.format(
header, total_time_str, total_time / max_iters), flush=True)
def glob_with_latest_modified_first(pattern, recursive=False):
return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)
def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
info = []
file = os.path.join(args.local_out_dir_path, pattern)
all_ckpt = glob_with_latest_modified_first(file)
if len(all_ckpt) == 0:
info.append(f'[auto_resume] no ckpt found @ {file}')
info.append(f'[auto_resume quit]')
return info, 0, 0, {}, {}
else:
info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
ckpt = torch.load(all_ckpt[0], map_location='cpu')
ep, it = ckpt['epoch'], ckpt['iter']
info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
return info, ep, it, ckpt['trainer'], ckpt['args']
def create_npz_from_sample_folder(sample_folder: str):
"""
Builds a single .npz file from a folder of .png samples. Refer to DiT.
"""
import os, glob
import numpy as np
from tqdm import tqdm
from PIL import Image
samples = []
pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
with Image.open(png) as sample_pil:
sample_np = np.asarray(sample_pil).astype(np.uint8)
samples.append(sample_np)
samples = np.stack(samples)
assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
npz_path = f'{sample_folder}.npz'
np.savez(npz_path, arr_0=samples)
print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
return npz_path