Commit 846290bf authored by suily's avatar suily
Browse files

Update test.py

parent 67f3d069
Pipeline #1903 canceled with stages
......@@ -9,7 +9,7 @@ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # 禁用默
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # 禁用默认参数init以获得更快的速度
from models import VQVAE, build_vae_var
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["HIP_VISIBLE_DEVICES"] = "2"
print(torch.cuda.get_device_name(0))
MODEL_DEPTH = 16 # TODO:更改此处,指定模型
assert MODEL_DEPTH in {16, 20, 24, 30}
......@@ -75,4 +75,4 @@ chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.save("./result/inference.png") # TODO:更改此处,指定推理结果存储地址
# chw.show()
\ No newline at end of file
# chw.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment