"driver/vscode:/vscode.git/clone" did not exist on "9a383af9aabf5749a63d875c46ca1cfa92b3acef"
Commit d03ea00f authored by suily's avatar suily
Browse files

Initial commit

parents
Pipeline #1898 canceled with stages
################## 1. Download checkpoints and build the models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable the default parameter init for faster startup
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable the default parameter init for faster startup
from models import VQVAE, build_vae_var
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
print(torch.cuda.get_device_name(0))
MODEL_DEPTH = 16    # TODO: change this to pick the model depth
assert MODEL_DEPTH in {16, 20, 24, 30}
# download checkpoint
model_path="./checkpoint/" # TODO:更改此处,指定模型地址
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = model_path+'vae_ch160v4096z32.pth', model_path+f'var_d{MODEL_DEPTH}.pth'
assert os.path.exists(vae_ckpt), f'missing VAE checkpoint: {vae_ckpt}'
assert os.path.exists(var_ckpt), f'missing VAR checkpoint: {var_ckpt}'
# if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/vae_ch160v4096z32.pth -O {vae_ckpt}')
# if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/var_d{MODEL_DEPTH}.pth -O {var_ckpt}')
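# A pure-Python alternative sketch to the wget lines above, for when the files
# are missing (torch.hub.download_url_to_file is part of core PyTorch and
# streams the URL straight to the given local path):
# if not osp.exists(vae_ckpt):
#     torch.hub.download_url_to_file(f'{hf_home}/vae_ch160v4096z32.pth', vae_ckpt)
# if not osp.exists(var_ckpt):
#     torch.hub.download_url_to_file(f'{hf_home}/var_d{MODEL_DEPTH}.pth', var_ckpt)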
# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
assert torch.cuda.is_available(), 'this script requires a CUDA device'
device = 'cuda'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )
# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print('preparation finished.')
############################# 2. Sample with classifier-free guidance
# set args
seed = 0  #@param {type:"number"}
num_sampling_steps = 250  #@param {type:"slider", min:0, max:1000, step:1}  # (not referenced below)
cfg = 4  #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}  # TODO: change these ImageNet labels (the label-to-class mapping file is in VAR/dataset); they determine which classes are generated
more_smooth = False  # True for more smooth output
# seed everything
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():  # inference: generate the image tensor
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):  # set the inference precision
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)
chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)  # assemble the samples into one grid image
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.save("./result/inference.png")  # TODO: change this to the path where the result should be saved
# chw.show()
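# A sketch for saving every sample as its own file rather than one grid image
# (filenames are illustrative; torchvision.utils.save_image expects values in [0, 1],
# which matches recon_B3HW here):
# os.makedirs('./result', exist_ok=True)
# for i, img in enumerate(recon_B3HW):
#     torchvision.utils.save_image(img, f'./result/sample_{i}.png')
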
export CUDA_VISIBLE_DEVICES=1,2,3,4
export HSA_FORCE_FINE_GRAIN_PCIE=1  # multi-GPU: force fine-grained PCIe mode, which helps inter-GPU communication efficiency
export USE_MIOPEN_BATCHNORM=1       # performance tuning for multi-GPU parallel computation
# torchrun: the launcher for distributed training.
# --nproc_per_node=...: number of processes per node (i.e., GPUs used on each node).
# --nnodes=...: total number of nodes taking part in the training.
# --node_rank=...: rank of the current node (counting from 0).
# --master_addr=...: IP address of the master node.
# --master_port=...: port on the master node used for communication.
# --depth=16: model depth.
# --bs=...: total batch size.
# --ep=...: total number of training epochs.
# --fp16=1: enable FP16 training.
# --alng=1e-3: initial value of ada_lin.w (the gamma channels).
# --wpe=0.1: final lr at the end of training.
# A single-node launch follows; see the multi-node sketch after it.
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 train.py \
--depth=16 --bs=192 --ep=5 --fp16=1 --alng=1e-3 --wpe=0.1 \
--data_path=/home/VAR/dataset
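
# A sketch of the same run spread over two nodes (the address and port are
# placeholders; launch this once per node with the matching --node_rank):
# torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
#   --master_addr=10.0.0.1 --master_port=29500 train.py \
#   --depth=16 --bs=192 --ep=5 --fp16=1 --alng=1e-3 --wpe=0.1 \
#   --data_path=/home/VAR/dataset
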
import math
from typing import List, Optional, Tuple, Union
import torch
class NullCtx:
    def __enter__(self):
        pass
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class AmpOptimizer:
    def __init__(
        self,
        mixed_precision: int,
        optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
        grad_clip: float, n_gradient_accumulation: int = 1,
    ):
        self.enable_amp = mixed_precision > 0
        self.using_fp16_rather_bf16 = mixed_precision == 1
        
        if self.enable_amp:
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
            self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None  # only fp16 needs a loss scaler
        else:
            self.amp_ctx = NullCtx()
            self.scaler = None
        
        self.optimizer, self.names, self.paras = optimizer, names, paras  # paras have been filtered so that all of them require grad
        self.grad_clip = grad_clip
        # optimizers that expose a `global_grad_norm` attribute handle clipping themselves and
        # report the norm after stepping ("late"); otherwise gradients are clipped here ("early")
        self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
        
        self.r_accu = 1 / n_gradient_accumulation  # r_accu == 1.0 / n_gradient_accumulation
    
    def backward_clip_step(
        self, stepping: bool, loss: torch.Tensor,
    ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
        # backward
        loss = loss.mul(self.r_accu)  # r_accu == 1.0 / n_gradient_accumulation
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)
        
        if stepping:
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)
            if self.early_clipping:
                orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
            
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                scaler_sc: float = self.scaler.get_scale()
                if scaler_sc > 32768.:  # fp16 overflows above 65536, so letting the scale double past 32768 could be dangerous
                    self.scaler.update(new_scale=32768.)
                else:
                    self.scaler.update()
                try:
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()
            
            if self.late_clipping:
                orig_norm = self.optimizer.global_grad_norm
            
            self.optimizer.zero_grad(set_to_none=True)
        
        return orig_norm, scaler_sc
    
    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
    
    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
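
# A minimal usage sketch of AmpOptimizer (hypothetical names: `model`, `opt`,
# `names`, `paras`, and `loader` are assumed to exist; the accumulation factor
# of 4 matches n_gradient_accumulation):
#
#   amp_opt = AmpOptimizer(mixed_precision=1, optimizer=opt, names=names,
#                          paras=paras, grad_clip=2.0, n_gradient_accumulation=4)
#   for it, (inp, label) in enumerate(loader):
#       with amp_opt.amp_ctx:             # autocast region (fp16 here)
#           loss = model(inp, label)
#       stepping = (it + 1) % 4 == 0      # optimizer step every 4 micro-batches
#       grad_norm, log2_scale = amp_opt.backward_clip_step(stepping=stepping, loss=loss)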