"ppocr/vscode:/vscode.git/clone" did not exist on "0b34ad5b93f19cf2324308e52bee59daeb164c4d"
Commit 63bde97a authored by chenpangpang's avatar chenpangpang
Browse files

feat: 初始提交

parent 9cf8c6f1
Pipeline #1475 failed with stages
in 0 seconds
pytorch-lightning==2.1.2
gradio==3.41.2
huggingface-hub
einops
omegaconf
torchmetrics
webdataset
accelerate
tensorboard
PyMCubes
trimesh
rembg
transformers==4.34.1
diffusers==0.20.2
bitsandbytes
imageio[ffmpeg]
xatlas
plyfile
git+https://github.com/NVlabs/nvdiffrast/
\ No newline at end of file
import os
import argparse
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl
from src.utils.infer_util import remove_background, resize_foreground, save_video
def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False):
    """
    Build the camera tensor for a circular sweep of ``M`` rendering views.

    The poses come from ``get_circular_camera_poses``. Two layouts are
    produced, depending on the renderer:

    * ``is_flexicubes=True``: the inverted pose matrices, repeated to
      ``(batch_size, M, ...)``.
    * otherwise: each flattened pose concatenated with the flattened
      intrinsics of a 30-degree-FOV camera, repeated to ``(batch_size, M, ...)``.
    """
    orbit_poses = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)

    if is_flexicubes:
        # presumably camera-to-world -> world-to-camera (the helper's result is
        # named c2ws upstream) -- TODO confirm against camera_util.
        inverse_poses = torch.linalg.inv(orbit_poses)
        return inverse_poses.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    flat_extrinsics = orbit_poses.flatten(-2)
    flat_intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
    per_view = torch.cat([flat_extrinsics, flat_intrinsics], dim=-1)
    return per_view.unsqueeze(0).repeat(batch_size, 1, 1)
def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False):
    """
    Render all views in ``render_cameras`` from ``planes``, ``chunk_size`` views at a time.

    The FlexiCubes path calls ``model.forward_geometry`` and reads its
    ``'img'`` output; the other path calls ``model.forward_synthesizer`` and
    reads ``'images_rgb'``. Chunks are concatenated along dim 1 and the first
    batch element is returned.
    """
    # Pick the renderer and the key of its image output once, up front.
    if is_flexicubes:
        render_fn, output_key = model.forward_geometry, 'img'
    else:
        render_fn, output_key = model.forward_synthesizer, 'images_rgb'

    num_views = render_cameras.shape[1]
    rendered_chunks = [
        render_fn(
            planes,
            render_cameras[:, start:start+chunk_size],
            render_size=render_size,
        )[output_key]
        for start in tqdm(range(0, num_views, chunk_size))
    ]

    # we suppose batch size is always 1
    return torch.cat(rendered_chunks, dim=1)[0]
###############################################################################
# Arguments.
###############################################################################

parser = argparse.ArgumentParser()

# Required positional inputs.
parser.add_argument(
    'config',
    type=str,
    help='Path to config file.',
)
parser.add_argument(
    'input_path',
    type=str,
    help='Path to input image or directory.',
)

# Output location and sampling controls.
parser.add_argument(
    '--output_path',
    type=str,
    default='outputs/',
    help='Output directory.',
)
parser.add_argument(
    '--diffusion_steps',
    type=int,
    default=75,
    help='Denoising Sampling steps.',
)
parser.add_argument(
    '--seed',
    type=int,
    default=42,
    help='Random seed for sampling.',
)

# Reconstruction / rendering geometry.
parser.add_argument(
    '--scale',
    type=float,
    default=1.0,
    help='Scale of generated object.',
)
parser.add_argument(
    '--distance',
    type=float,
    default=4.5,
    help='Render distance.',
)
parser.add_argument(
    '--view',
    type=int,
    default=6,
    choices=[4, 6],
    help='Number of input views.',
)

# Boolean switches (all off by default).
parser.add_argument(
    '--no_rembg',
    action='store_true',
    help='Do not remove input background.',
)
parser.add_argument(
    '--export_texmap',
    action='store_true',
    help='Export a mesh with texture map.',
)
parser.add_argument(
    '--save_video',
    action='store_true',
    help='Save a circular-view video.',
)

args = parser.parse_args()

# Seed the random number generators before any model is built or sampled.
seed_everything(args.seed)
###############################################################################
# Stage 0: Configuration.
###############################################################################

config = OmegaConf.load(args.config)
config_name = os.path.basename(args.config).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config

# Configs whose file name starts with 'instant-mesh' select the FlexiCubes
# code path used by the rest of this script.
IS_FLEXICUBES = config_name.startswith('instant-mesh')

device = torch.device('cuda')

# load diffusion model
print('Loading diffusion model ...')
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing='trailing'
)

# load custom white-background UNet
# A local checkpoint named in the config wins; otherwise it is fetched from
# the Hugging Face hub.
# NOTE(review): torch.load unpickles the file -- only point this at trusted
# checkpoints.
print('Loading custom white-background unet ...')
if os.path.exists(infer_config.unet_path):
    unet_ckpt_path = infer_config.unet_path
else:
    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)

pipeline = pipeline.to(device)

# load reconstruction model
print('Loading reconstruction model ...')
model = instantiate_from_config(model_config)
if os.path.exists(infer_config.model_path):
    model_ckpt_path = infer_config.model_path
else:
    model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
# Keep only the entries under the 'lrm_generator.' prefix and strip that
# prefix, so the keys can be loaded into `model` with strict=True.
generator_prefix = 'lrm_generator.'
state_dict = {
    k[len(generator_prefix):]: v
    for k, v in state_dict.items()
    if k.startswith(generator_prefix)
}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
if IS_FLEXICUBES:
    model.init_flexicubes_geometry(device, fovy=30.0)
model = model.eval()

# make output directories
image_path = os.path.join(args.output_path, config_name, 'images')
mesh_path = os.path.join(args.output_path, config_name, 'meshes')
video_path = os.path.join(args.output_path, config_name, 'videos')
os.makedirs(image_path, exist_ok=True)
os.makedirs(mesh_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)

# process input files
if os.path.isdir(args.input_path):
    # os.listdir returns entries in arbitrary order. Sort them so that, for a
    # fixed --seed, the images are sampled in the same order on every run.
    input_files = [
        os.path.join(args.input_path, file)
        for file in sorted(os.listdir(args.input_path))
        if file.endswith(('.png', '.jpg', '.webp'))
    ]
else:
    input_files = [args.input_path]
print(f'Total number of input images: {len(input_files)}')
###############################################################################
# Stage 1: Multiview generation.
###############################################################################

# A background-removal session is only created when it will be used.
rembg_session = None if args.no_rembg else rembg.new_session()

outputs = []
total_inputs = len(input_files)
for position, image_file in enumerate(input_files, start=1):
    # Everything before the first '.' of the file name identifies the sample.
    name = os.path.basename(image_file).split('.')[0]
    print(f'[{position}/{total_inputs}] Imagining {name} ...')

    # remove background optionally
    input_image = Image.open(image_file)
    if not args.no_rembg:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)

    # sampling
    sampled = pipeline(
        input_image,
        num_inference_steps=args.diffusion_steps,
    )
    output_image = sampled.images[0]

    saved_image_path = os.path.join(image_path, f'{name}.png')
    output_image.save(saved_image_path)
    print(f"Image saved to {saved_image_path}")

    # Convert the generated grid image to a float tensor in [0, 1] and cut
    # the 3x2 grid into six separate views.
    grid = np.asarray(output_image, dtype=np.float32) / 255.0
    grid = torch.from_numpy(grid).permute(2, 0, 1).contiguous().float()      # (3, 960, 640)
    views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)         # (6, 3, 320, 320)

    outputs.append({'name': name, 'images': views})

# delete pipeline to save memory
del pipeline
###############################################################################
# Stage 2: Reconstruction.
###############################################################################

input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*args.scale).to(device)
chunk_size = 20 if IS_FLEXICUBES else 1

# With --view 4 only a subset of the six generated views is used. Select the
# matching cameras once, before the loop: re-indexing `input_cameras` on every
# iteration would leave it with four entries after the first sample, and the
# same indices would then be out of range for the second sample.
view_indices = None
if args.view == 4:
    view_indices = torch.tensor([0, 2, 4, 5]).long().to(device)
    input_cameras = input_cameras[:, view_indices]

for idx, sample in enumerate(outputs):
    name = sample['name']
    print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')

    images = sample['images'].unsqueeze(0).to(device)
    images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)

    if view_indices is not None:
        # Keep the images in step with the camera subset chosen above.
        images = images[:, view_indices]

    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        # get mesh
        mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=args.export_texmap,
            **infer_config,
        )
        if args.export_texmap:
            # Textured export: the mesh comes with UVs and a texture map.
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            # Plain export: per-vertex colors only.
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        # get video
        if args.save_video:
            video_path_idx = os.path.join(video_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            render_cameras = get_render_cameras(
                batch_size=1,
                M=120,
                radius=args.distance,
                elevation=20.0,
                is_flexicubes=IS_FLEXICUBES,
            ).to(device)

            frames = render_frames(
                model,
                planes,
                render_cameras=render_cameras,
                render_size=render_size,
                chunk_size=chunk_size,
                is_flexicubes=IS_FLEXICUBES,
            )

            save_video(
                frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")
import os, sys
import math
import json
import importlib
from pathlib import Path
import cv2
import random
import numpy as np
from PIL import Image
import webdataset as wds
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
center_looking_at_camera_pose,
get_circular_camera_poses,
)
class DataModuleFromConfig(pl.LightningDataModule):
    """
    Lightning data module whose datasets are instantiated from config entries.

    A split ('train', 'validation', 'test') is registered only when a config
    for it is passed. Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # Keep only the splits that were actually configured, in a fixed order.
        candidates = (('train', train), ('validation', validation), ('test', test))
        self.dataset_configs = {
            split: split_config
            for split, split_config in candidates
            if split_config is not None
        }

    def setup(self, stage):
        """Instantiate every configured dataset; only the 'fit' stage is supported."""
        if stage not in ['fit']:
            raise NotImplementedError

        self.datasets = {
            split: instantiate_from_config(split_config)
            for split, split_config in self.dataset_configs.items()
        }

    def train_dataloader(self):
        """Loader over the 'train' dataset using a DistributedSampler."""
        train_set = self.datasets['train']
        return wds.WebLoader(
            train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=DistributedSampler(train_set),
        )

    def val_dataloader(self):
        """Loader over the 'validation' dataset; the batch size is fixed to 1."""
        validation_set = self.datasets['validation']
        return wds.WebLoader(
            validation_set,
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=DistributedSampler(validation_set),
        )

    def test_dataloader(self):
        """Loader over the 'test' dataset, without a sampler."""
        test_set = self.datasets['test']
        return wds.WebLoader(
            test_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
class ObjaverseData(Dataset):
    """
    Multi-view dataset read from per-object rendering folders.

    The list of object folders is read from the 'good_objs' entry of the JSON
    file ``meta_fname`` under ``root_dir``. Each item draws
    ``input_view_num + target_view_num`` distinct views out of
    ``total_view_n`` and returns their images, alphas, depths, normals,
    camera poses and intrinsics, split into 'input_*' and 'target_*' entries.

    ``validation`` is accepted but not used by this class.
    """

    def __init__(self,
        root_dir='objaverse/',
        meta_fname='valid_paths.json',
        input_image_dir='rendering_random_32views',
        target_image_dir='rendering_random_32views',
        input_view_num=6,
        target_view_num=4,
        total_view_n=32,
        fov=50,
        camera_rotation=True,
        validation=False,
    ):
        self.root_dir = Path(root_dir)
        self.input_image_dir = input_image_dir
        self.target_image_dir = target_image_dir

        self.input_view_num = input_view_num
        self.target_view_num = target_view_num
        self.total_view_n = total_view_n
        self.fov = fov
        self.camera_rotation = camera_rotation

        with open(os.path.join(root_dir, meta_fname)) as f:
            filtered_dict = json.load(f)
        self.paths = filtered_dict['good_objs']

        # Depth PNG values are divided by 255 and multiplied by this factor.
        self.depth_scale = 6.0

        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        '''
        Load an image with an alpha channel and composite it over `color`.

        Returns (image, alpha) as float tensors in channel-first layout.
        '''
        # Close the file handle as soon as the pixel data has been copied out.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.

        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def _load_views(self, view_dir, view_indices):
        '''
        Load (image, alpha, depth, normal, pose) for each index in
        `view_indices` from the rendering folder `view_dir`.
        '''
        bg_white = [1., 1., 1.]   # background behind RGB renderings
        bg_black = [0., 0., 0.]   # background behind normal maps

        with np.load(os.path.join(view_dir, 'cameras.npz')) as camera_file:
            cam_poses = camera_file['cam_poses']

        views = []
        for idx in view_indices:
            image, alpha = self.load_im(os.path.join(view_dir, '%03d.png' % idx), bg_white)
            normal, _ = self.load_im(os.path.join(view_dir, '%03d_normal.png' % idx), bg_black)
            depth = cv2.imread(os.path.join(view_dir, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
            depth = torch.from_numpy(depth).unsqueeze(0)
            # Append the homogeneous row [0, 0, 0, 1] to the stored pose.
            pose = np.concatenate([cam_poses[idx], np.array([[0, 0, 0, 1]])], axis=0)
            views.append((image, alpha, depth, normal, pose))
        return views

    def __getitem__(self, index):
        # Best-effort loading: if anything about the chosen object fails to
        # load, the error is printed and another random object is tried.
        # NOTE(review): this never terminates if no object can be loaded.
        while True:
            input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
            target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])

            indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
            input_indices = indices[:self.input_view_num]
            target_indices = indices[self.input_view_num:]

            try:
                views = self._load_views(input_image_path, input_indices)
                views += self._load_views(target_image_path, target_indices)
            except Exception as e:
                print(e)
                index = np.random.randint(0, len(self.paths))
                continue

            break

        image_list, alpha_list, depth_list, normal_list, pose_list = zip(*views)

        images = torch.stack(image_list, dim=0).float()                  # (6+V, 3, H, W)
        alphas = torch.stack(alpha_list, dim=0).float()                  # (6+V, 1, H, W)
        depths = torch.stack(depth_list, dim=0).float()                  # (6+V, 1, H, W)
        normals = torch.stack(normal_list, dim=0).float()                # (6+V, 3, H, W)
        # The stored poses are treated as world-to-camera and inverted.
        w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float()     # (6+V, 4, 4)
        c2ws = torch.linalg.inv(w2cs).float()

        # Map normals from [0, 1] to [-1, 1], renormalize, map back, and zero
        # them where alpha is zero.
        normals = normals * 2.0 - 1.0
        normals = F.normalize(normals, dim=1)
        normals = (normals + 1.0) / 2.0
        normals = torch.lerp(torch.zeros_like(normals), normals, alphas)

        # random rotation along z axis
        if self.camera_rotation:
            degree = np.random.uniform(0, math.pi * 2)   # angle in radians
            rot = torch.tensor([
                [np.cos(degree), -np.sin(degree), 0, 0],
                [np.sin(degree), np.cos(degree), 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]).unsqueeze(0).float()
            c2ws = torch.matmul(rot, c2ws)

            # rotate normals
            N, _, H, W = normals.shape
            normals = normals * 2.0 - 1.0
            normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
            normals = F.normalize(normals, dim=1)
            normals = (normals + 1.0) / 2.0
            normals = torch.lerp(torch.zeros_like(normals), normals, alphas)

        # random scaling
        if np.random.rand() < 0.5:
            scale = np.random.uniform(0.7, 1.1)
            c2ws[:, :3, 3] *= scale
            depths *= scale

        # intrinsics of perspective cameras
        K = FOV_to_intrinsics(self.fov)
        Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()

        data = {
            'input_images': images[:self.input_view_num],           # (6, 3, H, W)
            'input_alphas': alphas[:self.input_view_num],           # (6, 1, H, W)
            'input_depths': depths[:self.input_view_num],           # (6, 1, H, W)
            'input_normals': normals[:self.input_view_num],         # (6, 3, H, W)
            'input_c2ws': c2ws[:self.input_view_num],               # (6, 4, 4)
            'input_Ks': Ks[:self.input_view_num],                   # (6, 3, 3)

            # lrm generator input and supervision
            'target_images': images[self.input_view_num:],          # (V, 3, H, W)
            'target_alphas': alphas[self.input_view_num:],          # (V, 1, H, W)
            'target_depths': depths[self.input_view_num:],          # (V, 1, H, W)
            'target_normals': normals[self.input_view_num:],        # (V, 3, H, W)
            'target_c2ws': c2ws[self.input_view_num:],              # (V, 4, 4)
            'target_Ks': Ks[self.input_view_num:],                  # (V, 3, 3)
        }
        return data
class ValidationData(Dataset):
    """
    Validation dataset: every entry of ``root_dir`` is a folder of input views.

    Input cameras are placed at six fixed azimuth/elevation pairs at distance
    4.0. Eight additional cameras from ``get_circular_camera_poses`` at
    elevation 20 are provided for rendering.
    """

    def __init__(self,
        root_dir='objaverse/',
        input_view_num=6,
        input_image_size=320,
        fov=30,
    ):
        self.root_dir = Path(root_dir)
        self.input_view_num = input_view_num
        self.input_image_size = input_image_size
        self.fov = fov

        self.paths = sorted(os.listdir(self.root_dir))
        print('============= length of dataset %d =============' % len(self.paths))

        # Fixed input camera layout, in degrees.
        cam_distance = 4.0
        azimuths = np.array([30, 90, 150, 210, 270, 330])
        elevations = np.array([20, -10, 20, -10, 20, -10])
        azimuths = np.deg2rad(azimuths)
        elevations = np.deg2rad(elevations)

        # Spherical -> Cartesian camera positions.
        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
        z = cam_distance * np.sin(elevations)

        cam_locations = np.stack([x, y, z], axis=-1)
        cam_locations = torch.from_numpy(cam_locations).float()
        c2ws = center_looking_at_camera_pose(cam_locations)
        self.c2ws = c2ws.float()
        # One intrinsics matrix per fixed input camera defined above.
        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(len(azimuths), 1, 1).float()

        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
        self.render_c2ws = render_c2ws.float()
        self.render_Ks = render_Ks.float()

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        '''
        Load an image, resize it to the configured input size and, when it
        has an alpha channel, composite it over `color`.

        Returns (image, alpha) as channel-first float tensors; alpha is all
        ones for images without an alpha channel.
        '''
        # resize() returns a new image, so the file can be closed right away.
        with Image.open(path) as source_img:
            pil_img = source_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[-1] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        # load data
        input_image_path = os.path.join(self.root_dir, self.paths[index])

        # background color, default: white
        bkg_color = [1.0, 1.0, 1.0]

        image_list = []
        alpha_list = []
        for idx in range(self.input_view_num):
            image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
            image_list.append(image)
            alpha_list.append(alpha)

        images = torch.stack(image_list, dim=0).float()
        alphas = torch.stack(alpha_list, dim=0).float()

        # NOTE(review): the cameras always describe the six fixed views, while
        # the images follow input_view_num -- confirm callers keep it at 6.
        data = {
            'input_images': images,
            'input_alphas': alphas,
            'input_c2ws': self.c2ws,
            'input_Ks': self.Ks,
            'render_c2ws': self.render_c2ws,
            'render_Ks': self.render_Ks,
        }
        return data
import os
import json
import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
from pathlib import Path
from src.utils.train_util import instantiate_from_config
class DataModuleFromConfig(pl.LightningDataModule):
    """
    Lightning data module that builds its datasets from config entries.

    Only the splits for which a config is given are registered. Unrecognized
    keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dataset_configs = {}
        for split_name, split_config in (
            ('train', train),
            ('validation', validation),
            ('test', test),
        ):
            if split_config is not None:
                self.dataset_configs[split_name] = split_config

    def setup(self, stage):
        """Build all configured datasets; any stage other than 'fit' is rejected."""
        if stage not in ['fit']:
            raise NotImplementedError

        built = {}
        for split_name in self.dataset_configs:
            built[split_name] = instantiate_from_config(self.dataset_configs[split_name])
        self.datasets = built

    def _make_loader(self, split_name, batch_size, distributed):
        """Create a WebLoader for one split, optionally with a DistributedSampler."""
        dataset = self.datasets[split_name]
        if distributed:
            sampler = DistributedSampler(dataset)
            return wds.WebLoader(dataset, batch_size=batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
        return wds.WebLoader(dataset, batch_size=batch_size, num_workers=self.num_workers, shuffle=False)

    def train_dataloader(self):
        return self._make_loader('train', self.batch_size, distributed=True)

    def val_dataloader(self):
        # The validation batch size is fixed to 4.
        return self._make_loader('validation', 4, distributed=True)

    def test_dataloader(self):
        return self._make_loader('test', self.batch_size, distributed=False)
class ObjaverseData(Dataset):
    """
    Dataset of per-object image folders: one conditioning image plus six targets.

    The JSON file ``meta_fname`` under ``root_dir`` maps keys to lists of
    object paths. All lists are concatenated in key order; the last 16 paths
    form the validation split and the rest the training split.
    """

    def __init__(self,
        root_dir='objaverse/',
        meta_fname='valid_paths.json',
        image_dir='rendering_zero123plus',
        validation=False,
    ):
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir

        with open(os.path.join(root_dir, meta_fname)) as f:
            lvis_dict = json.load(f)
        paths = [path for group in lvis_dict.values() for path in group]

        if validation:
            self.paths = paths[-16:]  # used last 16 as validation
        else:
            self.paths = paths[:-16]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """
        Load an image with an alpha channel and composite it over `color`.

        Returns (image, alpha) as channel-first float tensors.
        """
        # Close the file handle as soon as the pixel data has been copied out.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.

        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        # Best-effort loading: on any failure the error is printed and another
        # random object is tried.
        # NOTE(review): this never terminates if no object can be loaded.
        while True:
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])

            # background color, default: white
            bkg_color = [1., 1., 1.]

            img_list = []
            try:
                # Seven images per object: 000.png is the condition image,
                # 001.png - 006.png are the targets.
                for idx in range(7):
                    img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
                    img_list.append(img)

            except Exception as e:
                print(e)
                index = np.random.randint(0, len(self.paths))
                continue

            break

        imgs = torch.stack(img_list, dim=0).float()

        data = {
            'cond_imgs': imgs[0],           # (3, H, W)
            'target_imgs': imgs[1:],        # (6, 3, H, W)
        }
        return data
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat
from src.utils.train_util import instantiate_from_config
class MVRecon(pl.LightningModule):
    """
    Lightning module that trains a multi-view reconstruction generator.

    The generator is built from ``lrm_generator_config``. It is trained with
    an MSE + LPIPS image loss and an MSE alpha-mask loss on randomly
    resized and cropped renderings.

    ``self.logdir`` and ``self.learning_rate`` are read but never assigned in
    this class; presumably the training script sets them before fitting --
    TODO confirm.
    """

    def __init__(
        self,
        lrm_generator_config,
        lrm_path=None,
        input_size=256,
        render_size=192,
    ):
        super(MVRecon, self).__init__()

        self.input_size = input_size
        self.render_size = render_size

        # init modules
        self.lrm_generator = instantiate_from_config(lrm_generator_config)
        if lrm_path is not None:
            # Load onto the CPU first so the checkpoint can be read on a
            # machine without the device it was saved from; load_state_dict
            # copies the values into the module's own parameters.
            # NOTE(review): torch.load unpickles the file -- use trusted
            # checkpoints only.
            lrm_ckpt = torch.load(lrm_path, map_location='cpu')
            self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)

        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

        self.validation_step_outputs = []

    def on_fit_start(self):
        """Create the image output folders (rank 0 only)."""
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)

    @staticmethod
    def _pack_input_cameras(input_c2ws, input_Ks):
        """
        Pack flattened poses and intrinsics into one camera vector per view.

        Keeps the first 12 values of each flattened pose and entries
        0, 4, 2, 5 of each flattened intrinsics matrix.
        """
        input_extrinsics = input_c2ws[:, :, :12]
        # For a row-major 3x3 K these entries are K[0,0], K[1,1], K[0,2], K[1,2].
        input_intrinsics = torch.stack([
            input_Ks[:, :, 0], input_Ks[:, :, 4],
            input_Ks[:, :, 2], input_Ks[:, :, 5],
        ], dim=-1)
        return torch.cat([input_extrinsics, input_intrinsics], dim=-1)

    def prepare_batch_data(self, batch):
        """
        Turn a training batch into generator inputs and supervision targets.

        Returns ``(lrm_generator_input, render_gt)``. Supervision covers both
        the input and the target views, resized to a random resolution and
        cropped to ``self.render_size``.
        """
        lrm_generator_input = {}
        render_gt = {}   # for supervision

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras and render cameras
        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)
        target_c2ws = batch['target_c2ws'].flatten(-2)
        target_Ks = batch['target_Ks'].flatten(-2)
        render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
        render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
        render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)

        cameras = self._pack_input_cameras(input_c2ws, input_Ks)

        # add noise to input cameras (uniform in [-0.02, 0.02))
        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02

        lrm_generator_input['cameras'] = cameras.to(self.device)
        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)

        # target images
        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)

        # random crop
        # Resize to a random resolution in [render_size, 512], then crop a
        # render_size window. The upper bound is never below render_size, so
        # a render_size above 512 yields a fixed resolution instead of an
        # error from randint.
        max_render_size = max(self.render_size, 512)
        render_size = np.random.randint(self.render_size, max_render_size + 1)
        target_images = v2.functional.resize(
            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
        target_depths = v2.functional.resize(
            target_depths, render_size, interpolation=0, antialias=True)
        target_alphas = v2.functional.resize(
            target_alphas, render_size, interpolation=0, antialias=True)

        # The same crop window is applied to images, depths and alphas.
        crop_params = v2.RandomCrop.get_params(
            target_images, output_size=(self.render_size, self.render_size))
        target_images = v2.functional.crop(target_images, *crop_params)
        target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
        target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]

        lrm_generator_input['render_size'] = render_size
        lrm_generator_input['crop_params'] = crop_params

        render_gt['target_images'] = target_images.to(self.device)
        render_gt['target_depths'] = target_depths.to(self.device)
        render_gt['target_alphas'] = target_alphas.to(self.device)

        return lrm_generator_input, render_gt

    def prepare_validation_batch_data(self, batch):
        """
        Turn a validation batch into generator inputs.

        No camera noise and no cropping; the render size is fixed to 384.
        """
        lrm_generator_input = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)

        cameras = self._pack_input_cameras(input_c2ws, input_Ks)

        lrm_generator_input['cameras'] = cameras.to(self.device)

        render_c2ws = batch['render_c2ws'].flatten(-2)
        render_Ks = batch['render_Ks'].flatten(-2)
        render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)

        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
        lrm_generator_input['render_size'] = 384
        lrm_generator_input['crop_params'] = None

        return lrm_generator_input

    def forward_lrm_generator(
        self,
        images,
        cameras,
        render_cameras,
        render_size=192,
        crop_params=None,
        chunk_size=1,
    ):
        """
        Compute planes from the input views and render ``render_cameras``
        in chunks of ``chunk_size`` views.

        Both stages go through ``torch.utils.checkpoint``. Returns a dict
        with the per-chunk outputs concatenated along dim 1.
        """
        planes = torch.utils.checkpoint.checkpoint(
            self.lrm_generator.forward_planes,
            images,
            cameras,
            use_reentrant=False,
        )
        frames = []
        for i in range(0, render_cameras.shape[1], chunk_size):
            frames.append(
                torch.utils.checkpoint.checkpoint(
                    self.lrm_generator.synthesizer,
                    planes,
                    cameras=render_cameras[:, i:i+chunk_size],
                    render_size=render_size,
                    crop_params=crop_params,
                    use_reentrant=False
                )
            )
        frames = {
            k: torch.cat([r[k] for r in frames], dim=1)
            for k in frames[0].keys()
        }
        return frames

    def forward(self, lrm_generator_input):
        """Run the generator and return clamped images/alphas plus raw depths."""
        images = lrm_generator_input['images']
        cameras = lrm_generator_input['cameras']
        render_cameras = lrm_generator_input['render_cameras']
        render_size = lrm_generator_input['render_size']
        crop_params = lrm_generator_input['crop_params']

        out = self.forward_lrm_generator(
            images,
            cameras,
            render_cameras,
            render_size=render_size,
            crop_params=crop_params,
            chunk_size=1,
        )
        render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
        render_depths = out['images_depth']
        render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)

        out = {
            'render_images': render_images,
            'render_depths': render_depths,
            'render_alphas': render_alphas,
        }
        return out

    def training_step(self, batch, batch_idx):
        """
        One optimization step; every 1000 global steps rank 0 also saves a
        grid of inputs, targets and renderings.
        """
        lrm_generator_input, render_gt = self.prepare_batch_data(batch)

        render_out = self.forward(lrm_generator_input)

        loss, loss_dict = self.compute_loss(render_out, render_gt)

        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.global_step % 1000 == 0 and self.global_rank == 0:
            B, N, C, H, W = render_gt['target_images'].shape
            N_in = lrm_generator_input['images'].shape[1]

            # Pad the input row with white images so it is as wide as the
            # target/render rows.
            input_images = v2.functional.resize(
                lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
            input_images = torch.cat(
                [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)

            input_images = rearrange(
                input_images, 'b n c h w -> b c h (n w)')
            target_images = rearrange(
                render_gt['target_images'], 'b n c h w -> b c h (n w)')
            render_images = rearrange(
                render_out['render_images'], 'b n c h w -> b c h (n w)')
            target_alphas = rearrange(
                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_alphas = rearrange(
                repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_depths = rearrange(
                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_depths = rearrange(
                repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            # Depths are scaled by the largest target depth for display.
            MAX_DEPTH = torch.max(target_depths)
            target_depths = target_depths / MAX_DEPTH * target_alphas
            render_depths = render_depths / MAX_DEPTH

            grid = torch.cat([
                input_images,
                target_images, render_images,
                target_alphas, render_alphas,
                target_depths, render_depths,
            ], dim=-2)
            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))

            save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))

        return loss

    def compute_loss(self, render_out, render_gt):
        """
        Return ``(loss, loss_dict)`` where loss = MSE + 2 * LPIPS on images
        (rescaled to [-1, 1]) + MSE on alphas.
        """
        # NOTE: the rgb value range of OpenLRM is [0, 1]
        render_images = render_out['render_images']
        target_images = render_gt['target_images'].to(render_images)
        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0

        loss_mse = F.mse_loss(render_images, target_images)
        loss_lpips = 2.0 * self.lpips(render_images, target_images)

        render_alphas = render_out['render_alphas']
        target_alphas = render_gt['target_alphas']
        loss_mask = F.mse_loss(render_alphas, target_alphas)

        loss = loss_mse + loss_lpips + loss_mask

        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Render the validation views and stash them for the epoch-end grid."""
        lrm_generator_input = self.prepare_validation_batch_data(batch)

        render_out = self.forward(lrm_generator_input)
        render_images = render_out['render_images']
        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')

        self.validation_step_outputs.append(render_images)

    def on_validation_epoch_end(self):
        """Gather the stashed renderings from all ranks and save one grid on rank 0."""
        images = torch.cat(self.validation_step_outputs, dim=-1)

        all_images = self.all_gather(images)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW over the generator parameters with cosine warm restarts."""
        lr = self.learning_rate

        params = []

        params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })

        optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat
from src.utils.train_util import instantiate_from_config
# Regularization loss for FlexiCubes
def sdf_reg_loss_batch(sdf, all_edges):
    """Cross-entropy regularizer on SDF values across sign-changing edges.

    For every edge in ``all_edges`` the two endpoint SDF values are looked up
    per batch item. Only edges whose endpoints have different signs are kept;
    for those, each endpoint value is treated as a logit and penalised with
    binary cross-entropy against the sign (``> 0``) of the other endpoint.

    Args:
        sdf: tensor indexed as ``sdf[:, vertex_index]`` (batch first).
        all_edges: integer tensor of vertex indices; it is flattened and
            regrouped into pairs, so it must hold an even number of entries.

    Returns:
        A scalar tensor: the sum of the two mean-reduced BCE terms, or zero
        when no edge has a sign change (the mean over an empty selection
        would otherwise be NaN and poison the total loss).
    """
    edge_sdf = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
    # torch.sign maps 0 to 0, so an exact zero paired with a non-zero value
    # also counts as a sign change.
    mask = torch.sign(edge_sdf[..., 0]) != torch.sign(edge_sdf[..., 1])
    edge_sdf = edge_sdf[mask]

    if edge_sdf.shape[0] == 0:
        return sdf.new_zeros(())

    first, second = edge_sdf[..., 0], edge_sdf[..., 1]
    sdf_diff = F.binary_cross_entropy_with_logits(first, (second > 0).float()) + \
        F.binary_cross_entropy_with_logits(second, (first > 0).float())
    return sdf_diff
class MVRecon(pl.LightningModule):
    """Lightning module that trains a FlexiCubes-based multi-view reconstruction model.

    The wrapped ``lrm_generator`` turns input images and cameras into planes
    (``forward_planes``) and renders them from given cameras
    (``forward_geometry``). Training supervises rendered colour, mask, depth
    and normals, plus FlexiCubes regularization terms.

    NOTE(review): ``self.logdir`` and ``self.learning_rate`` are read but never
    assigned in this class; they are presumably set by the training script
    before ``fit`` -- confirm against the caller.
    """

    def __init__(
        self,
        lrm_generator_config,
        input_size=256,
        render_size=512,
        init_ckpt=None,
    ):
        """Build the generator and LPIPS metric, optionally warm-starting from a checkpoint.

        Args:
            lrm_generator_config: config passed to ``instantiate_from_config``.
            input_size: size the input images are resized to.
            render_size: size the training targets are resized to / rendered at.
            init_ckpt: optional path to a checkpoint whose ``state_dict`` holds
                ``lrm_generator.*`` weights used for initialization.
        """
        super(MVRecon, self).__init__()

        self.input_size = input_size
        self.render_size = render_size

        # init modules
        self.lrm_generator = instantiate_from_config(lrm_generator_config)

        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

        # Load weights from pretrained MVRecon model, and use the mlp
        # weights to initialize the weights of sdf and rgb mlps.
        if init_ckpt is not None:
            # NOTE(security): torch.load unpickles the file; only pass
            # checkpoints from a trusted source.
            sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
            sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
            sd_fc = {}
            for k, v in sd.items():
                if k.startswith('lrm_generator.synthesizer.decoder.net.'):
                    if k.startswith('lrm_generator.synthesizer.decoder.net.6.'):    # last layer
                        # Here we assume the density field's isosurface threshold is t,
                        # we reverse the sign of density field to initialize SDF field.
                        # -(w*x + b - t) = (-w)*x + (t - b)
                        if 'weight' in k:
                            sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
                        else:
                            # bias: t - b with t = 10.0
                            sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
                        # rows 1..3 of the last layer initialize the rgb head
                        sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
                    else:
                        # earlier decoder layers are copied into both heads unchanged
                        sd_fc[k.replace('net.', 'net_sdf.')] = v
                        sd_fc[k.replace('net.', 'net_rgb.')] = v
                else:
                    sd_fc[k] = v
            sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
            # missing `net_deformation` and `net_weight` parameters
            self.lrm_generator.load_state_dict(sd_fc, strict=False)
            print(f'Loaded weights from {init_ckpt}')

        # per-epoch buffer filled by validation_step, consumed by on_validation_epoch_end
        self.validation_step_outputs = []

    def on_fit_start(self):
        """Initialize the FlexiCubes geometry on this process's GPU and create the image folders."""
        # The CUDA ordinal is node-local, so it must come from local_rank;
        # global_rank exceeds the per-node device count on multi-node runs.
        device = torch.device(f'cuda:{self.local_rank}')
        self.lrm_generator.init_flexicubes_geometry(device)
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)

    def prepare_batch_data(self, batch):
        """Convert a training batch into generator inputs and ground-truth targets.

        Args:
            batch: dict with the keys ``input_images``, ``input_c2ws``,
                ``input_Ks``, ``target_c2ws``, ``target_images`` and the
                ``input_*`` / ``target_*`` variants of ``depths``, ``alphas``
                and ``normals``.

        Returns:
            ``(lrm_generator_input, render_gt)``: the first holds ``images``,
            ``cameras``, ``render_cameras`` and ``render_size``; the second
            holds ``target_images``, ``target_depths``, ``target_alphas`` and
            ``target_normals``. Input and target views are concatenated along
            dim 1 for both the render cameras and the targets.
        """
        lrm_generator_input = {}
        render_gt = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras and render cameras
        input_c2ws = batch['input_c2ws']
        input_Ks = batch['input_Ks']
        target_c2ws = batch['target_c2ws']

        render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
        render_w2cs = torch.linalg.inv(render_c2ws)

        # keep the first 12 entries of each flattened camera-to-world matrix
        input_extrinsics = input_c2ws.flatten(-2)
        input_extrinsics = input_extrinsics[:, :, :12]
        # entries 0, 4, 2, 5 of the flattened intrinsics
        # (fx, fy, cx, cy if K is 3x3 -- TODO confirm)
        input_intrinsics = input_Ks.flatten(-2)
        input_intrinsics = torch.stack([
            input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
            input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        # add noise to input_cameras (uniform in [-0.02, 0.02))
        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02

        lrm_generator_input['cameras'] = cameras.to(self.device)
        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)

        # target images
        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
        target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)

        render_size = self.render_size
        target_images = v2.functional.resize(
            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
        target_depths = v2.functional.resize(
            target_depths, render_size, interpolation=0, antialias=True)
        target_alphas = v2.functional.resize(
            target_alphas, render_size, interpolation=0, antialias=True)
        target_normals = v2.functional.resize(
            target_normals, render_size, interpolation=3, antialias=True)

        lrm_generator_input['render_size'] = render_size

        render_gt['target_images'] = target_images.to(self.device)
        render_gt['target_depths'] = target_depths.to(self.device)
        render_gt['target_alphas'] = target_alphas.to(self.device)
        render_gt['target_normals'] = target_normals.to(self.device)

        return lrm_generator_input, render_gt

    def prepare_validation_batch_data(self, batch):
        """Convert a validation batch into generator inputs.

        Unlike ``prepare_batch_data`` no camera noise is added, the render
        cameras come from ``batch['render_c2ws']``, and the render size is
        fixed to 384.

        Returns:
            dict with ``images``, ``cameras``, ``render_cameras`` and
            ``render_size``.
        """
        lrm_generator_input = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras
        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)

        input_extrinsics = input_c2ws[:, :, :12]
        input_intrinsics = torch.stack([
            input_Ks[:, :, 0], input_Ks[:, :, 4],
            input_Ks[:, :, 2], input_Ks[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        lrm_generator_input['cameras'] = cameras.to(self.device)

        # render cameras
        render_c2ws = batch['render_c2ws']
        render_w2cs = torch.linalg.inv(render_c2ws)

        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
        lrm_generator_input['render_size'] = 384

        return lrm_generator_input

    def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
        """Predict planes from the inputs and render them from ``render_cameras``.

        The plane prediction runs under activation checkpointing, so its
        intermediate activations are recomputed during the backward pass
        instead of being stored.
        """
        planes = torch.utils.checkpoint.checkpoint(
            self.lrm_generator.forward_planes,
            images,
            cameras,
            use_reentrant=False,
        )
        out = self.lrm_generator.forward_geometry(
            planes,
            render_cameras,
            render_size,
        )
        return out

    def forward(self, lrm_generator_input):
        """Unpack the input dict and run ``forward_lrm_generator``; returns its output."""
        images = lrm_generator_input['images']
        cameras = lrm_generator_input['cameras']
        render_cameras = lrm_generator_input['render_cameras']
        render_size = lrm_generator_input['render_size']

        out = self.forward_lrm_generator(
            images, cameras, render_cameras, render_size=render_size)

        return out

    def training_step(self, batch, batch_idx):
        """Run one optimization step and periodically dump a comparison image.

        Every 1000 global steps, global rank 0 writes a grid to
        ``<logdir>/images/train_<global_step>.png`` that stacks target and
        rendered images, alphas, depths and normals.

        Returns:
            The total training loss.
        """
        lrm_generator_input, render_gt = self.prepare_batch_data(batch)

        render_out = self.forward(lrm_generator_input)

        loss, loss_dict = self.compute_loss(render_out, render_gt)

        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.global_step % 1000 == 0 and self.global_rank == 0:
            # lay the views of each sample side by side along the width axis
            target_images = rearrange(
                render_gt['target_images'], 'b n c h w -> b c h (n w)')
            render_images = rearrange(
                render_out['img'], 'b n c h w -> b c h (n w)')
            # single-channel maps are repeated to 3 channels so they can be
            # stacked with the colour images
            target_alphas = rearrange(
                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_alphas = rearrange(
                repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_depths = rearrange(
                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_depths = rearrange(
                repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_normals = rearrange(
                render_gt['target_normals'], 'b n c h w -> b c h (n w)')
            render_normals = rearrange(
                render_out['normal'], 'b n c h w -> b c h (n w)')
            # normalize depths by the largest target depth for display
            MAX_DEPTH = torch.max(target_depths)
            target_depths = target_depths / MAX_DEPTH * target_alphas
            render_depths = render_depths / MAX_DEPTH

            grid = torch.cat([
                target_images, render_images,
                target_alphas, render_alphas,
                target_depths, render_depths,
                target_normals, render_normals,
            ], dim=-2)
            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))

            image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        return loss

    def compute_loss(self, render_out, render_gt):
        """Combine the image, mask, depth, normal and FlexiCubes regularization losses.

        Weights applied here: LPIPS x2.0, depth L1 x0.5, normal x0.2, SDF
        entropy x0.01, FlexiCubes surface x0.5, FlexiCubes weights x0.1. MSE
        and mask MSE are unweighted.

        Args:
            render_out: generator output with the keys ``img``, ``mask``,
                ``depth``, ``normal``, ``sdf`` and ``sdf_reg_loss``.
            render_gt: targets produced by ``prepare_batch_data``.

        Returns:
            ``(loss, loss_dict)`` where ``loss_dict`` maps ``train/...`` names
            to the individual terms and the total.
        """
        # NOTE: the rgb value range of OpenLRM is [0, 1]
        render_images = render_out['img']
        target_images = render_gt['target_images'].to(render_images)
        # fold views into the batch axis and rescale from [0, 1] to [-1, 1]
        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        loss_mse = F.mse_loss(render_images, target_images)
        loss_lpips = 2.0 * self.lpips(render_images, target_images)

        render_alphas = render_out['mask']
        target_alphas = render_gt['target_alphas']
        loss_mask = F.mse_loss(render_alphas, target_alphas)

        # depth is only supervised where the target alpha is positive
        render_depths = render_out['depth']
        target_depths = render_gt['target_depths']
        loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])

        # normals: 1 - |dot product| over the channel axis, inside the target mask
        render_normals = render_out['normal'] * 2.0 - 1.0
        target_normals = render_gt['target_normals'] * 2.0 - 1.0
        similarity = (render_normals * target_normals).sum(dim=-3).abs()
        normal_mask = target_alphas.squeeze(-3)
        loss_normal = 1 - similarity[normal_mask>0].mean()
        loss_normal = 0.2 * loss_normal

        # flexicubes regularization loss
        sdf = render_out['sdf']
        sdf_reg_loss = render_out['sdf_reg_loss']
        sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
        # the first element of the tuple is not used here
        _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
        flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
        flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1

        loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg

        loss = loss_mse + loss_lpips + loss_mask + loss_depth + loss_normal + loss_reg

        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
        loss_dict.update({f'{prefix}/loss_normal': loss_normal})
        loss_dict.update({f'{prefix}/loss_depth': loss_depth})
        loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
        loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
        loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Render a validation batch and buffer the view strips for the epoch-end hook."""
        lrm_generator_input = self.prepare_validation_batch_data(batch)

        render_out = self.forward(lrm_generator_input)
        render_images = render_out['img']
        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')

        self.validation_step_outputs.append(render_images)

    def on_validation_epoch_end(self):
        """Gather the buffered validation renders and save them as one image on rank 0.

        The image is written to ``<logdir>/images_val/val_<global_step>.png``;
        the buffer is cleared on every rank afterwards.
        """
        images = torch.cat(self.validation_step_outputs, dim=-1)

        all_images = self.all_gather(images)
        # Lightning's ``all_gather`` does not add the leading world-size axis
        # when running with a single process; add it so the pattern below
        # always matches.
        if all_images.dim() == images.dim():
            all_images = all_images.unsqueeze(0)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW over the generator parameters with a cosine warm-restart schedule.

        Returns:
            dict with the keys ``'optimizer'`` and ``'lr_scheduler'``.
        """
        lr = self.learning_rate

        optimizer = torch.optim.AdamW(
            self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
\ No newline at end of file
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class BasicTransformerBlock(nn.Module):
    """
    Pre-norm transformer block conditioned through cross-attention.

    Sub-layers, each wrapped in a residual connection, in order:
    cross-attention to the condition sequence, self-attention, and an MLP.
    Attention is provided by torch.nn.MultiheadAttention.
    """
    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.,
        attn_bias: bool = False,
        mlp_ratio: float = 4.,
        mlp_drop: float = 0.,
    ):
        super().__init__()
        # NOTE: `eps` is accepted but not passed to the LayerNorms below, so
        # they use nn.LayerNorm's default epsilon.
        hidden_dim = int(inner_dim * mlp_ratio)
        shared_attn_args = dict(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )

        # Modules are created in this exact order so that parameter
        # initialization consumes the RNG in the same sequence.
        self.norm1 = nn.LayerNorm(inner_dim)
        # keys/values come from the condition, which may have its own width
        self.cross_attn = nn.MultiheadAttention(
            kdim=cond_dim, vdim=cond_dim, **shared_attn_args)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.self_attn = nn.MultiheadAttention(**shared_attn_args)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        """
        x:    [N, L, D] token sequence
        cond: [N, L_cond, D_cond] condition sequence
        Returns a tensor shaped like x.
        """
        # cross-attention: queries from normalized x, keys/values from cond
        cross_out, _ = self.cross_attn(self.norm1(x), cond, cond)
        x = x + cross_out

        # self-attention over the normalized tokens
        normed = self.norm2(x)
        self_out, _ = self.self_attn(normed, normed, normed)
        x = x + self_out

        # position-wise feed-forward
        return x + self.mlp(self.norm3(x))
class TriplaneTransformer(nn.Module):
    """
    Conditioned transformer that decodes learned positional tokens into a triplane.

    A learned embedding with 3 * triplane_low_res**2 tokens is refined by a
    stack of BasicTransformerBlock layers attending to image features, then
    each of the three planes is upsampled by a stride-2 transposed convolution.

    Reference:
    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(
        self,
        inner_dim: int,
        image_feat_dim: int,
        triplane_low_res: int,
        triplane_high_res: int,
        triplane_dim: int,
        num_layers: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # pos_embed is drawn from N(0, 1) and scaled by 1/sqrt(inner_dim)
        num_tokens = 3 * triplane_low_res ** 2
        init_scale = (1. / inner_dim) ** 0.5
        self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, inner_dim) * init_scale)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                BasicTransformerBlock(
                    inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
            )
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats):
        """
        image_feats: [N, L_cond, D_cond]
        Returns the triplane tensor [N, 3, D', H', W'].
        """
        batch = image_feats.shape[0]
        res = self.triplane_low_res

        # every sample starts from the same learned tokens: [N, L, D]
        tokens = self.pos_embed.repeat(batch, 1, 1)
        for block in self.layers:
            tokens = block(tokens, image_feats)
        tokens = self.norm(tokens)

        # split the sequence into three planes and move channels first
        planes = tokens.view(batch, 3, res, res, -1)
        planes = torch.einsum('nihwd->indhw', planes).contiguous()  # [3, N, D, H, W]
        planes = planes.view(3 * batch, -1, res, res)               # [3*N, D, H, W]

        # upsample all planes in one batched call
        planes = self.deconv(planes)                                # [3*N, D', H', W']

        # restore the batch-first layout
        planes = planes.view(3, batch, *planes.shape[-3:])          # [3, N, D', H', W']
        return torch.einsum('indhw->nidhw', planes).contiguous()    # [N, 3, D', H', W']
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment