"test/vscode:/vscode.git/clone" did not exist on "aba9eae4c653ee4949bb7d5723b4d1b918d206b6"
Commit 63bde97a authored by chenpangpang's avatar chenpangpang
Browse files

feat: 初始提交

parent 9cf8c6f1
Pipeline #1475 failed with stages
in 0 seconds
pytorch-lightning==2.1.2
gradio==3.41.2
huggingface-hub
einops
omegaconf
torchmetrics
webdataset
accelerate
tensorboard
PyMCubes
trimesh
rembg
transformers==4.34.1
diffusers==0.20.2
bitsandbytes
imageio[ffmpeg]
xatlas
plyfile
git+https://github.com/NVlabs/nvdiffrast/
\ No newline at end of file
import os
import argparse
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl
from src.utils.infer_util import remove_background, resize_foreground, save_video
def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False):
"""
Get the rendering camera parameters.
"""
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
if is_flexicubes:
cameras = torch.linalg.inv(c2ws)
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
else:
extrinsics = c2ws.flatten(-2)
intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
return cameras
def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False):
"""
Render frames from triplanes.
"""
frames = []
for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
if is_flexicubes:
frame = model.forward_geometry(
planes,
render_cameras[:, i:i+chunk_size],
render_size=render_size,
)['img']
else:
frame = model.forward_synthesizer(
planes,
render_cameras[:, i:i+chunk_size],
render_size=render_size,
)['images_rgb']
frames.append(frame)
frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
return frames
###############################################################################
# Arguments.
###############################################################################
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('input_path', type=str, help='Path to input image or directory.')
parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising Sampling steps.')
parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
args = parser.parse_args()
seed_everything(args.seed)
###############################################################################
# Stage 0: Configuration.
###############################################################################
config = OmegaConf.load(args.config)
config_name = os.path.basename(args.config).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config
IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
device = torch.device('cuda')
# load diffusion model
print('Loading diffusion model ...')
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="zero123plus",
torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
# load custom white-background UNet
print('Loading custom white-background unet ...')
if os.path.exists(infer_config.unet_path):
unet_ckpt_path = infer_config.unet_path
else:
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)
pipeline = pipeline.to(device)
# load reconstruction model
print('Loading reconstruction model ...')
model = instantiate_from_config(model_config)
if os.path.exists(infer_config.model_path):
model_ckpt_path = infer_config.model_path
else:
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
if IS_FLEXICUBES:
model.init_flexicubes_geometry(device, fovy=30.0)
model = model.eval()
# make output directories
image_path = os.path.join(args.output_path, config_name, 'images')
mesh_path = os.path.join(args.output_path, config_name, 'meshes')
video_path = os.path.join(args.output_path, config_name, 'videos')
os.makedirs(image_path, exist_ok=True)
os.makedirs(mesh_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)
# process input files
if os.path.isdir(args.input_path):
input_files = [
os.path.join(args.input_path, file)
for file in os.listdir(args.input_path)
if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
]
else:
input_files = [args.input_path]
print(f'Total number of input images: {len(input_files)}')
###############################################################################
# Stage 1: Multiview generation.
###############################################################################
rembg_session = None if args.no_rembg else rembg.new_session()
outputs = []
for idx, image_file in enumerate(input_files):
name = os.path.basename(image_file).split('.')[0]
print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')
# remove background optionally
input_image = Image.open(image_file)
if not args.no_rembg:
input_image = remove_background(input_image, rembg_session)
input_image = resize_foreground(input_image, 0.85)
# sampling
output_image = pipeline(
input_image,
num_inference_steps=args.diffusion_steps,
).images[0]
output_image.save(os.path.join(image_path, f'{name}.png'))
print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")
images = np.asarray(output_image, dtype=np.float32) / 255.0
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
outputs.append({'name': name, 'images': images})
# delete pipeline to save memory
del pipeline
###############################################################################
# Stage 2: Reconstruction.
###############################################################################
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*args.scale).to(device)
chunk_size = 20 if IS_FLEXICUBES else 1
for idx, sample in enumerate(outputs):
name = sample['name']
print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')
images = sample['images'].unsqueeze(0).to(device)
images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)
if args.view == 4:
indices = torch.tensor([0, 2, 4, 5]).long().to(device)
images = images[:, indices]
input_cameras = input_cameras[:, indices]
with torch.no_grad():
# get triplane
planes = model.forward_planes(images, input_cameras)
# get mesh
mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')
mesh_out = model.extract_mesh(
planes,
use_texture_map=args.export_texmap,
**infer_config,
)
if args.export_texmap:
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
save_obj_with_mtl(
vertices.data.cpu().numpy(),
uvs.data.cpu().numpy(),
faces.data.cpu().numpy(),
mesh_tex_idx.data.cpu().numpy(),
tex_map.permute(1, 2, 0).data.cpu().numpy(),
mesh_path_idx,
)
else:
vertices, faces, vertex_colors = mesh_out
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
print(f"Mesh saved to {mesh_path_idx}")
# get video
if args.save_video:
video_path_idx = os.path.join(video_path, f'{name}.mp4')
render_size = infer_config.render_resolution
render_cameras = get_render_cameras(
batch_size=1,
M=120,
radius=args.distance,
elevation=20.0,
is_flexicubes=IS_FLEXICUBES,
).to(device)
frames = render_frames(
model,
planes,
render_cameras=render_cameras,
render_size=render_size,
chunk_size=chunk_size,
is_flexicubes=IS_FLEXICUBES,
)
save_video(
frames,
video_path_idx,
fps=30,
)
print(f"Video saved to {video_path_idx}")
import os, sys
import math
import json
import importlib
from pathlib import Path
import cv2
import random
import numpy as np
from PIL import Image
import webdataset as wds
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
center_looking_at_camera_pose,
get_circular_camera_poses,
)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size=8,
num_workers=4,
train=None,
validation=None,
test=None,
**kwargs,
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset_configs = dict()
if train is not None:
self.dataset_configs['train'] = train
if validation is not None:
self.dataset_configs['validation'] = validation
if test is not None:
self.dataset_configs['test'] = test
def setup(self, stage):
if stage in ['fit']:
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
else:
raise NotImplementedError
def train_dataloader(self):
sampler = DistributedSampler(self.datasets['train'])
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
def val_dataloader(self):
sampler = DistributedSampler(self.datasets['validation'])
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
def test_dataloader(self):
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
class ObjaverseData(Dataset):
def __init__(self,
root_dir='objaverse/',
meta_fname='valid_paths.json',
input_image_dir='rendering_random_32views',
target_image_dir='rendering_random_32views',
input_view_num=6,
target_view_num=4,
total_view_n=32,
fov=50,
camera_rotation=True,
validation=False,
):
self.root_dir = Path(root_dir)
self.input_image_dir = input_image_dir
self.target_image_dir = target_image_dir
self.input_view_num = input_view_num
self.target_view_num = target_view_num
self.total_view_n = total_view_n
self.fov = fov
self.camera_rotation = camera_rotation
with open(os.path.join(root_dir, meta_fname)) as f:
filtered_dict = json.load(f)
paths = filtered_dict['good_objs']
self.paths = paths
self.depth_scale = 6.0
total_objects = len(self.paths)
print('============= length of dataset %d =============' % len(self.paths))
def __len__(self):
return len(self.paths)
def load_im(self, path, color):
'''
replace background pixel with random color in rendering
'''
pil_img = Image.open(path)
image = np.asarray(pil_img, dtype=np.float32) / 255.
alpha = image[:, :, 3:]
image = image[:, :, :3] * alpha + color * (1 - alpha)
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
return image, alpha
def __getitem__(self, index):
while True:
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
input_indices = indices[:self.input_view_num]
target_indices = indices[self.input_view_num:]
'''background color, default: white'''
bg_white = [1., 1., 1.]
bg_black = [0., 0., 0.]
image_list = []
alpha_list = []
depth_list = []
normal_list = []
pose_list = []
try:
input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
for idx in input_indices:
image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
depth = torch.from_numpy(depth).unsqueeze(0)
pose = input_cameras[idx]
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
image_list.append(image)
alpha_list.append(alpha)
depth_list.append(depth)
normal_list.append(normal)
pose_list.append(pose)
target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
for idx in target_indices:
image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
depth = torch.from_numpy(depth).unsqueeze(0)
pose = target_cameras[idx]
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
image_list.append(image)
alpha_list.append(alpha)
depth_list.append(depth)
normal_list.append(normal)
pose_list.append(pose)
except Exception as e:
print(e)
index = np.random.randint(0, len(self.paths))
continue
break
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
c2ws = torch.linalg.inv(w2cs).float()
normals = normals * 2.0 - 1.0
normals = F.normalize(normals, dim=1)
normals = (normals + 1.0) / 2.0
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
# random rotation along z axis
if self.camera_rotation:
degree = np.random.uniform(0, math.pi * 2)
rot = torch.tensor([
[np.cos(degree), -np.sin(degree), 0, 0],
[np.sin(degree), np.cos(degree), 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]).unsqueeze(0).float()
c2ws = torch.matmul(rot, c2ws)
# rotate normals
N, _, H, W = normals.shape
normals = normals * 2.0 - 1.0
normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
normals = F.normalize(normals, dim=1)
normals = (normals + 1.0) / 2.0
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
# random scaling
if np.random.rand() < 0.5:
scale = np.random.uniform(0.7, 1.1)
c2ws[:, :3, 3] *= scale
depths *= scale
# instrinsics of perspective cameras
K = FOV_to_intrinsics(self.fov)
Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
data = {
'input_images': images[:self.input_view_num], # (6, 3, H, W)
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
'input_c2ws': c2ws[:self.input_view_num], # (6, 4, 4)
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
# lrm generator input and supervision
'target_images': images[self.input_view_num:], # (V, 3, H, W)
'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
}
return data
class ValidationData(Dataset):
def __init__(self,
root_dir='objaverse/',
input_view_num=6,
input_image_size=320,
fov=30,
):
self.root_dir = Path(root_dir)
self.input_view_num = input_view_num
self.input_image_size = input_image_size
self.fov = fov
self.paths = sorted(os.listdir(self.root_dir))
print('============= length of dataset %d =============' % len(self.paths))
cam_distance = 4.0
azimuths = np.array([30, 90, 150, 210, 270, 330])
elevations = np.array([20, -10, 20, -10, 20, -10])
azimuths = np.deg2rad(azimuths)
elevations = np.deg2rad(elevations)
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
z = cam_distance * np.sin(elevations)
cam_locations = np.stack([x, y, z], axis=-1)
cam_locations = torch.from_numpy(cam_locations).float()
c2ws = center_looking_at_camera_pose(cam_locations)
self.c2ws = c2ws.float()
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
self.render_c2ws = render_c2ws.float()
self.render_Ks = render_Ks.float()
def __len__(self):
return len(self.paths)
def load_im(self, path, color):
'''
replace background pixel with random color in rendering
'''
pil_img = Image.open(path)
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
image = np.asarray(pil_img, dtype=np.float32) / 255.
if image.shape[-1] == 4:
alpha = image[:, :, 3:]
image = image[:, :, :3] * alpha + color * (1 - alpha)
else:
alpha = np.ones_like(image[:, :, :1])
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
return image, alpha
def __getitem__(self, index):
# load data
input_image_path = os.path.join(self.root_dir, self.paths[index])
'''background color, default: white'''
bkg_color = [1.0, 1.0, 1.0]
image_list = []
alpha_list = []
for idx in range(self.input_view_num):
image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
image_list.append(image)
alpha_list.append(alpha)
images = torch.stack(image_list, dim=0).float()
alphas = torch.stack(alpha_list, dim=0).float()
data = {
'input_images': images,
'input_alphas': alphas,
'input_c2ws': self.c2ws,
'input_Ks': self.Ks,
'render_c2ws': self.render_c2ws,
'render_Ks': self.render_Ks,
}
return data
import os
import json
import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
from pathlib import Path
from src.utils.train_util import instantiate_from_config
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size=8,
num_workers=4,
train=None,
validation=None,
test=None,
**kwargs,
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset_configs = dict()
if train is not None:
self.dataset_configs['train'] = train
if validation is not None:
self.dataset_configs['validation'] = validation
if test is not None:
self.dataset_configs['test'] = test
def setup(self, stage):
if stage in ['fit']:
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
else:
raise NotImplementedError
def train_dataloader(self):
sampler = DistributedSampler(self.datasets['train'])
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
def val_dataloader(self):
sampler = DistributedSampler(self.datasets['validation'])
return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)
def test_dataloader(self):
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
class ObjaverseData(Dataset):
def __init__(self,
root_dir='objaverse/',
meta_fname='valid_paths.json',
image_dir='rendering_zero123plus',
validation=False,
):
self.root_dir = Path(root_dir)
self.image_dir = image_dir
with open(os.path.join(root_dir, meta_fname)) as f:
lvis_dict = json.load(f)
paths = []
for k in lvis_dict.keys():
paths.extend(lvis_dict[k])
self.paths = paths
total_objects = len(self.paths)
if validation:
self.paths = self.paths[-16:] # used last 16 as validation
else:
self.paths = self.paths[:-16]
print('============= length of dataset %d =============' % len(self.paths))
def __len__(self):
return len(self.paths)
def load_im(self, path, color):
pil_img = Image.open(path)
image = np.asarray(pil_img, dtype=np.float32) / 255.
alpha = image[:, :, 3:]
image = image[:, :, :3] * alpha + color * (1 - alpha)
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
return image, alpha
def __getitem__(self, index):
while True:
image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
'''background color, default: white'''
bkg_color = [1., 1., 1.]
img_list = []
try:
for idx in range(7):
img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
img_list.append(img)
except Exception as e:
print(e)
index = np.random.randint(0, len(self.paths))
continue
break
imgs = torch.stack(img_list, dim=0).float()
data = {
'cond_imgs': imgs[0], # (3, H, W)
'target_imgs': imgs[1:], # (6, 3, H, W)
}
return data
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat
from src.utils.train_util import instantiate_from_config
class MVRecon(pl.LightningModule):
def __init__(
self,
lrm_generator_config,
lrm_path=None,
input_size=256,
render_size=192,
):
super(MVRecon, self).__init__()
self.input_size = input_size
self.render_size = render_size
# init modules
self.lrm_generator = instantiate_from_config(lrm_generator_config)
if lrm_path is not None:
lrm_ckpt = torch.load(lrm_path)
self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
self.validation_step_outputs = []
def on_fit_start(self):
if self.global_rank == 0:
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
def prepare_batch_data(self, batch):
lrm_generator_input = {}
render_gt = {} # for supervision
# input images
images = batch['input_images']
images = v2.functional.resize(
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
lrm_generator_input['images'] = images.to(self.device)
# input cameras and render cameras
input_c2ws = batch['input_c2ws'].flatten(-2)
input_Ks = batch['input_Ks'].flatten(-2)
target_c2ws = batch['target_c2ws'].flatten(-2)
target_Ks = batch['target_Ks'].flatten(-2)
render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
input_extrinsics = input_c2ws[:, :, :12]
input_intrinsics = torch.stack([
input_Ks[:, :, 0], input_Ks[:, :, 4],
input_Ks[:, :, 2], input_Ks[:, :, 5],
], dim=-1)
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
# add noise to input cameras
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
lrm_generator_input['cameras'] = cameras.to(self.device)
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
# target images
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
# random crop
render_size = np.random.randint(self.render_size, 513)
target_images = v2.functional.resize(
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
target_depths = v2.functional.resize(
target_depths, render_size, interpolation=0, antialias=True)
target_alphas = v2.functional.resize(
target_alphas, render_size, interpolation=0, antialias=True)
crop_params = v2.RandomCrop.get_params(
target_images, output_size=(self.render_size, self.render_size))
target_images = v2.functional.crop(target_images, *crop_params)
target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
lrm_generator_input['render_size'] = render_size
lrm_generator_input['crop_params'] = crop_params
render_gt['target_images'] = target_images.to(self.device)
render_gt['target_depths'] = target_depths.to(self.device)
render_gt['target_alphas'] = target_alphas.to(self.device)
return lrm_generator_input, render_gt
def prepare_validation_batch_data(self, batch):
lrm_generator_input = {}
# input images
images = batch['input_images']
images = v2.functional.resize(
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
lrm_generator_input['images'] = images.to(self.device)
input_c2ws = batch['input_c2ws'].flatten(-2)
input_Ks = batch['input_Ks'].flatten(-2)
input_extrinsics = input_c2ws[:, :, :12]
input_intrinsics = torch.stack([
input_Ks[:, :, 0], input_Ks[:, :, 4],
input_Ks[:, :, 2], input_Ks[:, :, 5],
], dim=-1)
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
lrm_generator_input['cameras'] = cameras.to(self.device)
render_c2ws = batch['render_c2ws'].flatten(-2)
render_Ks = batch['render_Ks'].flatten(-2)
render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
lrm_generator_input['render_size'] = 384
lrm_generator_input['crop_params'] = None
return lrm_generator_input
def forward_lrm_generator(
self,
images,
cameras,
render_cameras,
render_size=192,
crop_params=None,
chunk_size=1,
):
planes = torch.utils.checkpoint.checkpoint(
self.lrm_generator.forward_planes,
images,
cameras,
use_reentrant=False,
)
frames = []
for i in range(0, render_cameras.shape[1], chunk_size):
frames.append(
torch.utils.checkpoint.checkpoint(
self.lrm_generator.synthesizer,
planes,
cameras=render_cameras[:, i:i+chunk_size],
render_size=render_size,
crop_params=crop_params,
use_reentrant=False
)
)
frames = {
k: torch.cat([r[k] for r in frames], dim=1)
for k in frames[0].keys()
}
return frames
def forward(self, lrm_generator_input):
images = lrm_generator_input['images']
cameras = lrm_generator_input['cameras']
render_cameras = lrm_generator_input['render_cameras']
render_size = lrm_generator_input['render_size']
crop_params = lrm_generator_input['crop_params']
out = self.forward_lrm_generator(
images,
cameras,
render_cameras,
render_size=render_size,
crop_params=crop_params,
chunk_size=1,
)
render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
render_depths = out['images_depth']
render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
out = {
'render_images': render_images,
'render_depths': render_depths,
'render_alphas': render_alphas,
}
return out
def training_step(self, batch, batch_idx):
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
render_out = self.forward(lrm_generator_input)
loss, loss_dict = self.compute_loss(render_out, render_gt)
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
if self.global_step % 1000 == 0 and self.global_rank == 0:
B, N, C, H, W = render_gt['target_images'].shape
N_in = lrm_generator_input['images'].shape[1]
input_images = v2.functional.resize(
lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
input_images = torch.cat(
[input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
input_images = rearrange(
input_images, 'b n c h w -> b c h (n w)')
target_images = rearrange(
render_gt['target_images'], 'b n c h w -> b c h (n w)')
render_images = rearrange(
render_out['render_images'], 'b n c h w -> b c h (n w)')
target_alphas = rearrange(
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
render_alphas = rearrange(
repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
target_depths = rearrange(
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
render_depths = rearrange(
repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
MAX_DEPTH = torch.max(target_depths)
target_depths = target_depths / MAX_DEPTH * target_alphas
render_depths = render_depths / MAX_DEPTH
grid = torch.cat([
input_images,
target_images, render_images,
target_alphas, render_alphas,
target_depths, render_depths,
], dim=-2)
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
return loss
def compute_loss(self, render_out, render_gt):
# NOTE: the rgb value range of OpenLRM is [0, 1]
render_images = render_out['render_images']
target_images = render_gt['target_images'].to(render_images)
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
loss_mse = F.mse_loss(render_images, target_images)
loss_lpips = 2.0 * self.lpips(render_images, target_images)
render_alphas = render_out['render_alphas']
target_alphas = render_gt['target_alphas']
loss_mask = F.mse_loss(render_alphas, target_alphas)
loss = loss_mse + loss_lpips + loss_mask
prefix = 'train'
loss_dict = {}
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
@torch.no_grad()
def validation_step(self, batch, batch_idx):
lrm_generator_input = self.prepare_validation_batch_data(batch)
render_out = self.forward(lrm_generator_input)
render_images = render_out['render_images']
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
self.validation_step_outputs.append(render_images)
def on_validation_epoch_end(self):
images = torch.cat(self.validation_step_outputs, dim=-1)
all_images = self.all_gather(images)
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
if self.global_rank == 0:
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
save_image(grid, image_path)
print(f"Saved image to {image_path}")
self.validation_step_outputs.clear()
def configure_optimizers(self):
lr = self.learning_rate
params = []
params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat
from src.utils.train_util import instantiate_from_config
# Regulrarization loss for FlexiCubes
def sdf_reg_loss_batch(sdf, all_edges):
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
sdf_f1x6x2 = sdf_f1x6x2[mask]
sdf_diff = F.binary_cross_entropy_with_logits(
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
F.binary_cross_entropy_with_logits(
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
return sdf_diff
class MVRecon(pl.LightningModule):
def __init__(
self,
lrm_generator_config,
input_size=256,
render_size=512,
init_ckpt=None,
):
super(MVRecon, self).__init__()
self.input_size = input_size
self.render_size = render_size
# init modules
self.lrm_generator = instantiate_from_config(lrm_generator_config)
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
# Load weights from pretrained MVRecon model, and use the mlp
# weights to initialize the weights of sdf and rgb mlps.
if init_ckpt is not None:
sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
sd_fc = {}
for k, v in sd.items():
if k.startswith('lrm_generator.synthesizer.decoder.net.'):
if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
# Here we assume the density filed's isosurface threshold is t,
# we reverse the sign of density filed to initialize SDF field.
# -(w*x + b - t) = (-w)*x + (t - b)
if 'weight' in k:
sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
else:
sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
else:
sd_fc[k.replace('net.', 'net_sdf.')] = v
sd_fc[k.replace('net.', 'net_rgb.')] = v
else:
sd_fc[k] = v
sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
# missing `net_deformation` and `net_weight` parameters
self.lrm_generator.load_state_dict(sd_fc, strict=False)
print(f'Loaded weights from {init_ckpt}')
self.validation_step_outputs = []
def on_fit_start(self):
device = torch.device(f'cuda:{self.global_rank}')
self.lrm_generator.init_flexicubes_geometry(device)
if self.global_rank == 0:
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
def prepare_batch_data(self, batch):
lrm_generator_input = {}
render_gt = {}
# input images
images = batch['input_images']
images = v2.functional.resize(
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
lrm_generator_input['images'] = images.to(self.device)
# input cameras and render cameras
input_c2ws = batch['input_c2ws']
input_Ks = batch['input_Ks']
target_c2ws = batch['target_c2ws']
render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
render_w2cs = torch.linalg.inv(render_c2ws)
input_extrinsics = input_c2ws.flatten(-2)
input_extrinsics = input_extrinsics[:, :, :12]
input_intrinsics = input_Ks.flatten(-2)
input_intrinsics = torch.stack([
input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
], dim=-1)
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
# add noise to input_cameras
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
lrm_generator_input['cameras'] = cameras.to(self.device)
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
# target images
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
render_size = self.render_size
target_images = v2.functional.resize(
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
target_depths = v2.functional.resize(
target_depths, render_size, interpolation=0, antialias=True)
target_alphas = v2.functional.resize(
target_alphas, render_size, interpolation=0, antialias=True)
target_normals = v2.functional.resize(
target_normals, render_size, interpolation=3, antialias=True)
lrm_generator_input['render_size'] = render_size
render_gt['target_images'] = target_images.to(self.device)
render_gt['target_depths'] = target_depths.to(self.device)
render_gt['target_alphas'] = target_alphas.to(self.device)
render_gt['target_normals'] = target_normals.to(self.device)
return lrm_generator_input, render_gt
def prepare_validation_batch_data(self, batch):
lrm_generator_input = {}
# input images
images = batch['input_images']
images = v2.functional.resize(
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
lrm_generator_input['images'] = images.to(self.device)
# input cameras
input_c2ws = batch['input_c2ws'].flatten(-2)
input_Ks = batch['input_Ks'].flatten(-2)
input_extrinsics = input_c2ws[:, :, :12]
input_intrinsics = torch.stack([
input_Ks[:, :, 0], input_Ks[:, :, 4],
input_Ks[:, :, 2], input_Ks[:, :, 5],
], dim=-1)
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
lrm_generator_input['cameras'] = cameras.to(self.device)
# render cameras
render_c2ws = batch['render_c2ws']
render_w2cs = torch.linalg.inv(render_c2ws)
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
lrm_generator_input['render_size'] = 384
return lrm_generator_input
def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
planes = torch.utils.checkpoint.checkpoint(
self.lrm_generator.forward_planes,
images,
cameras,
use_reentrant=False,
)
out = self.lrm_generator.forward_geometry(
planes,
render_cameras,
render_size,
)
return out
def forward(self, lrm_generator_input):
images = lrm_generator_input['images']
cameras = lrm_generator_input['cameras']
render_cameras = lrm_generator_input['render_cameras']
render_size = lrm_generator_input['render_size']
out = self.forward_lrm_generator(
images, cameras, render_cameras, render_size=render_size)
return out
def training_step(self, batch, batch_idx):
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
render_out = self.forward(lrm_generator_input)
loss, loss_dict = self.compute_loss(render_out, render_gt)
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
if self.global_step % 1000 == 0 and self.global_rank == 0:
B, N, C, H, W = render_gt['target_images'].shape
N_in = lrm_generator_input['images'].shape[1]
target_images = rearrange(
render_gt['target_images'], 'b n c h w -> b c h (n w)')
render_images = rearrange(
render_out['img'], 'b n c h w -> b c h (n w)')
target_alphas = rearrange(
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
render_alphas = rearrange(
repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
target_depths = rearrange(
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
render_depths = rearrange(
repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
target_normals = rearrange(
render_gt['target_normals'], 'b n c h w -> b c h (n w)')
render_normals = rearrange(
render_out['normal'], 'b n c h w -> b c h (n w)')
MAX_DEPTH = torch.max(target_depths)
target_depths = target_depths / MAX_DEPTH * target_alphas
render_depths = render_depths / MAX_DEPTH
grid = torch.cat([
target_images, render_images,
target_alphas, render_alphas,
target_depths, render_depths,
target_normals, render_normals,
], dim=-2)
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
save_image(grid, image_path)
print(f"Saved image to {image_path}")
return loss
def compute_loss(self, render_out, render_gt):
# NOTE: the rgb value range of OpenLRM is [0, 1]
render_images = render_out['img']
target_images = render_gt['target_images'].to(render_images)
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
loss_mse = F.mse_loss(render_images, target_images)
loss_lpips = 2.0 * self.lpips(render_images, target_images)
render_alphas = render_out['mask']
target_alphas = render_gt['target_alphas']
loss_mask = F.mse_loss(render_alphas, target_alphas)
render_depths = render_out['depth']
target_depths = render_gt['target_depths']
loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
render_normals = render_out['normal'] * 2.0 - 1.0
target_normals = render_gt['target_normals'] * 2.0 - 1.0
similarity = (render_normals * target_normals).sum(dim=-3).abs()
normal_mask = target_alphas.squeeze(-3)
loss_normal = 1 - similarity[normal_mask>0].mean()
loss_normal = 0.2 * loss_normal
# flexicubes regularization loss
sdf = render_out['sdf']
sdf_reg_loss = render_out['sdf_reg_loss']
sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
_, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
loss = loss_mse + loss_lpips + loss_mask + loss_depth + loss_normal + loss_reg
prefix = 'train'
loss_dict = {}
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
loss_dict.update({f'{prefix}/loss_normal': loss_normal})
loss_dict.update({f'{prefix}/loss_depth': loss_depth})
loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
@torch.no_grad()
def validation_step(self, batch, batch_idx):
lrm_generator_input = self.prepare_validation_batch_data(batch)
render_out = self.forward(lrm_generator_input)
render_images = render_out['img']
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
self.validation_step_outputs.append(render_images)
def on_validation_epoch_end(self):
images = torch.cat(self.validation_step_outputs, dim=-1)
all_images = self.all_gather(images)
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
if self.global_rank == 0:
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
save_image(grid, image_path)
print(f"Saved image to {image_path}")
self.validation_step_outputs.clear()
def configure_optimizers(self):
lr = self.learning_rate
optimizer = torch.optim.AdamW(
self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
\ No newline at end of file
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class BasicTransformerBlock(nn.Module):
"""
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
"""
# use attention from torch.nn.MultiHeadAttention
# Block contains a cross-attention layer, a self-attention layer, and a MLP
def __init__(
self,
inner_dim: int,
cond_dim: int,
num_heads: int,
eps: float,
attn_drop: float = 0.,
attn_bias: bool = False,
mlp_ratio: float = 4.,
mlp_drop: float = 0.,
):
super().__init__()
self.norm1 = nn.LayerNorm(inner_dim)
self.cross_attn = nn.MultiheadAttention(
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
dropout=attn_drop, bias=attn_bias, batch_first=True)
self.norm2 = nn.LayerNorm(inner_dim)
self.self_attn = nn.MultiheadAttention(
embed_dim=inner_dim, num_heads=num_heads,
dropout=attn_drop, bias=attn_bias, batch_first=True)
self.norm3 = nn.LayerNorm(inner_dim)
self.mlp = nn.Sequential(
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(mlp_drop),
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
nn.Dropout(mlp_drop),
)
def forward(self, x, cond):
# x: [N, L, D]
# cond: [N, L_cond, D_cond]
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
before_sa = self.norm2(x)
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
x = x + self.mlp(self.norm3(x))
return x
class TriplaneTransformer(nn.Module):
"""
Transformer with condition that generates a triplane representation.
Reference:
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
"""
def __init__(
self,
inner_dim: int,
image_feat_dim: int,
triplane_low_res: int,
triplane_high_res: int,
triplane_dim: int,
num_layers: int,
num_heads: int,
eps: float = 1e-6,
):
super().__init__()
# attributes
self.triplane_low_res = triplane_low_res
self.triplane_high_res = triplane_high_res
self.triplane_dim = triplane_dim
# modules
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
self.layers = nn.ModuleList([
BasicTransformerBlock(
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(inner_dim, eps=eps)
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
def forward(self, image_feats):
# image_feats: [N, L_cond, D_cond]
N = image_feats.shape[0]
H = W = self.triplane_low_res
L = 3 * H * W
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
for layer in self.layers:
x = layer(x, image_feats)
x = self.norm(x)
# separate each plane and apply deconv
x = x.view(N, 3, H, W, -1)
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
x = self.deconv(x) # [3*N, D', H', W']
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
x = x.contiguous()
return x
# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViT model."""
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
)
from transformers import PreTrainedModel, ViTConfig
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
class ViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class ViTPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
class ViTSelfAttention(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTAttention(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTOutput(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTAttention(config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
)
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
def forward(
self,
hidden_states: torch.Tensor,
adaln_input: torch.Tensor = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
self_attention_outputs = self.attention(
modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection
hidden_states = attention_output + hidden_states
# in ViT, layernorm is also applied after self-attention
layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
class ViTEncoder(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
adaln_input: torch.Tensor = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
adaln_input,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class ViTPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)
class ViTModel(ViTPreTrainedModel):
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
super().__init__(config)
self.config = config
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = ViTEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViTPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> ViTPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
adaln_input: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
adaln_input=adaln_input,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class ViTPooler(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment