################## 1. Download checkpoints and build the models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster model construction
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster model construction
from models import VQVAE, build_vae_var

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
print(torch.cuda.get_device_name(0))

MODEL_DEPTH = 16    # TODO: change this to select the model depth
assert MODEL_DEPTH in {16, 20, 24, 30}

# download checkpoint
model_path = "./checkpoint/"    # TODO: change this to point at your checkpoint directory
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = model_path + 'vae_ch160v4096z32.pth', model_path + f'var_d{MODEL_DEPTH}.pth'
assert os.path.exists(vae_ckpt)
assert os.path.exists(var_ckpt)
# if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/vae_ch160v4096z32.pth -O {vae_ckpt}')
# if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/var_d{MODEL_DEPTH}.pth -O {var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
assert torch.cuda.is_available(), 'a CUDA device is required for this script'
device = 'cuda'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print('prepare finished.')

############################# 2. Sampling with classifier-free guidance
# set args
seed = 0  #@param {type:"number"}
num_sampling_steps = 250  #@param {type:"slider", min:0, max:1000, step:1}
cfg = 4  #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}   # TODO: change these ImageNet class labels (the label-to-class mapping file is in VAR/dataset); they determine which classes are generated
more_smooth = False  # True for more smooth output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():    # run inference and generate the image tensor
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # fp16 mixed-precision inference
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

# arrange the generated batch into a single grid image and save it
chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
os.makedirs("./result", exist_ok=True)
chw.save("./result/inference.png")  # TODO: change this to set where the inference result is saved
# chw.show()
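
# Optional: in addition to the grid, save each generated image separately.
# This is a minimal sketch, not part of the original script; it assumes
# recon_B3HW holds B images of shape 3xHxW with values in [0, 1]
# (consistent with the mul_(255) conversion above) and reuses class_labels
# only to build descriptive file names.
# for i, (img_chw, cls) in enumerate(zip(recon_B3HW, class_labels)):
#     img = img_chw.permute(1, 2, 0).mul(255).cpu().numpy().astype(np.uint8)
#     PImage.fromarray(img).save(f"./result/inference_{i}_class{cls}.png")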