Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
import importlib
__attributes = {
'FlexiDualGridDataset': 'flexi_dual_grid',
'SparseVoxelPbrDataset':'sparse_voxel_pbr',
'SparseStructureLatent': 'sparse_structure_latent',
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
'SLat': 'structured_latent',
'ImageConditionedSLat': 'structured_latent',
'SLatShape': 'structured_latent_shape',
'ImageConditionedSLatShape': 'structured_latent_shape',
'SLatPbr': 'structured_latent_svpbr',
'ImageConditionedSLatPbr': 'structured_latent_svpbr',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .flexi_dual_grid import FlexiDualGridDataset
from .sparse_voxel_pbr import SparseVoxelPbrDataset
from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent
from .structured_latent import SLat, ImageConditionedSLat
from .structured_latent_shape import SLatShape, ImageConditionedSLatShape
from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr
\ No newline at end of file
from typing import *
import json
from abc import abstractmethod
import os
import json
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
class StandardDatasetBase(Dataset):
"""
Base class for standard datasets.
Args:
roots (str): paths to the dataset
"""
def __init__(self,
roots: str,
):
super().__init__()
try:
self.roots = json.loads(roots)
root_type = 'obj'
except:
self.roots = roots.split(',')
root_type = 'list'
self.instances = []
self.metadata = pd.DataFrame()
self._stats = {}
if root_type == 'obj':
for key, root in self.roots.items():
self._stats[key] = {}
metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
for _, r in root.items():
metadata = metadata.combine_first(pd.read_csv(os.path.join(r, 'metadata.csv')).set_index('sha256'))
self._stats[key]['Total'] = len(metadata)
metadata, stats = self.filter_metadata(metadata)
self._stats[key].update(stats)
self.instances.extend([(root, sha256) for sha256 in metadata.index.values])
self.metadata = pd.concat([self.metadata, metadata])
else:
for root in self.roots:
key = os.path.basename(root)
self._stats[key] = {}
metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
self._stats[key]['Total'] = len(metadata)
metadata, stats = self.filter_metadata(metadata)
self._stats[key].update(stats)
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
metadata.set_index('sha256', inplace=True)
self.metadata = pd.concat([self.metadata, metadata])
@abstractmethod
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
pass
@abstractmethod
def get_instance(self, root, instance: str) -> Dict[str, Any]:
pass
def __len__(self):
return len(self.instances)
def __getitem__(self, index) -> Dict[str, Any]:
try:
root, instance = self.instances[index]
return self.get_instance(root, instance)
except Exception as e:
print(f'Error loading {instance}: {e}')
return self.__getitem__(np.random.randint(0, len(self)))
def __str__(self):
lines = []
lines.append(self.__class__.__name__)
lines.append(f' - Total instances: {len(self)}')
lines.append(f' - Sources:')
for key, stats in self._stats.items():
lines.append(f' - {key}:')
for k, v in stats.items():
lines.append(f' - {k}: {v}')
return '\n'.join(lines)
class ImageConditionedMixin:
def __init__(self, roots, *, image_size=518, **kwargs):
self.image_size = image_size
super().__init__(roots, **kwargs)
def filter_metadata(self, metadata):
metadata, stats = super().filter_metadata(metadata)
metadata = metadata[metadata['cond_rendered'].notna()]
stats['Cond rendered'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
pack = super().get_instance(root, instance)
image_root = os.path.join(root['render_cond'], instance)
with open(os.path.join(image_root, 'transforms.json')) as f:
metadata = json.load(f)
n_views = len(metadata['frames'])
view = np.random.randint(n_views)
metadata = metadata['frames'][view]
image_path = os.path.join(image_root, metadata['file_path'])
image = Image.open(image_path)
alpha = np.array(image.getchannel(3))
bbox = np.array(alpha).nonzero()
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
aug_hsize = hsize
aug_center_offset = [0, 0]
aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
image = image.crop(aug_bbox)
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
alpha = image.getchannel(3)
image = image.convert('RGB')
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
alpha = torch.tensor(np.array(alpha)).float() / 255.0
image = image * alpha.unsqueeze(0)
pack['cond'] = image
return pack
class MultiImageConditionedMixin:
def __init__(self, roots, *, image_size=518, max_image_cond_view = 4, **kwargs):
self.image_size = image_size
self.max_image_cond_view = max_image_cond_view
super().__init__(roots, **kwargs)
def filter_metadata(self, metadata):
metadata, stats = super().filter_metadata(metadata)
metadata = metadata[metadata['cond_rendered'].notna()]
stats['Cond rendered'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
pack = super().get_instance(root, instance)
image_root = os.path.join(root['render_cond'], instance)
with open(os.path.join(image_root, 'transforms.json')) as f:
metadata = json.load(f)
n_views = len(metadata['frames'])
n_sample_views = np.random.randint(1, self.max_image_cond_view+1)
assert n_views >= n_sample_views, f'Not enough views to sample {n_sample_views} unique images.'
sampled_views = np.random.choice(n_views, size=n_sample_views, replace=False)
cond_images = []
for v in sampled_views:
frame_info = metadata['frames'][v]
image_path = os.path.join(image_root, frame_info['file_path'])
image = Image.open(image_path)
alpha = np.array(image.getchannel(3))
bbox = np.array(alpha).nonzero()
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
aug_hsize = hsize
aug_center = center
aug_bbox = [
int(aug_center[0] - aug_hsize),
int(aug_center[1] - aug_hsize),
int(aug_center[0] + aug_hsize),
int(aug_center[1] + aug_hsize),
]
img = image.crop(aug_bbox)
img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
alpha = img.getchannel(3)
img = img.convert('RGB')
img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
alpha = torch.tensor(np.array(alpha)).float() / 255.0
img = img * alpha.unsqueeze(0)
cond_images.append(img)
pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W)
return pack
import os
import numpy as np
import pickle
import torch
import utils3d
from .components import StandardDatasetBase
from ..modules import sparse as sp
from ..renderers import MeshRenderer
from ..representations import Mesh
from ..utils.data_utils import load_balanced_group_indices
import o_voxel
class FlexiDualGridVisMixin:
@torch.no_grad()
def visualize_sample(self, x: dict):
mesh = x['mesh']
renderer = MeshRenderer({'near': 1, 'far': 3})
renderer.rendering_options.resolution = 512
renderer.rendering_options.ssaa = 4
# Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
exts = []
ints = []
for yaw, pitch in zip(yaws, pitch):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(30)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
exts.append(extrinsics)
ints.append(intrinsics)
# Build each representation
images = []
for m in mesh:
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = \
renderer.render(m.cuda(), ext, intr)['normal']
images.append(image)
images = torch.stack(images)
return images
class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase):
"""
Flexible Dual Grid Dataset
Args:
roots (str): path to the dataset
resolution (int): resolution of the voxel grid
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
"""
def __init__(
self,
roots,
resolution: int = 1024,
max_active_voxels: int = 1000000,
max_num_faces: int = None,
min_aesthetic_score: float = 5.0,
):
self.resolution = resolution
self.min_aesthetic_score = min_aesthetic_score
self.max_active_voxels = max_active_voxels
self.max_num_faces = max_num_faces
self.value_range = (0, 1)
super().__init__(roots)
self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256 in self.instances]
def __str__(self):
lines = [
super().__str__(),
f' - Resolution: {self.resolution}',
]
return '\n'.join(lines)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata[f'dual_grid_converted'] == True]
stats['Dual Grid Converted'] = len(metadata)
if self.min_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels]
stats[f'Active Voxels <= {self.max_active_voxels}'] = len(metadata)
if self.max_num_faces is not None:
metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
return metadata, stats
def read_mesh(self, root, instance):
with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
dump = pickle.load(f)
start = 0
vertices = []
faces = []
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
vertices.append(obj['vertices'])
faces.append(obj['faces'] + start)
start += len(obj['vertices'])
vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
vertices = (vertices - center) * scale
assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
return {'mesh': [Mesh(vertices=vertices, faces=faces)]}
def read_dual_grid(self, root, instance):
coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
vertices = sp.SparseTensor(
(attr['vertices'] / 255.0).float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
intersected = vertices.replace(torch.cat([
attr['intersected'] % 2,
attr['intersected'] // 2 % 2,
attr['intersected'] // 4 % 2,
], dim=-1).bool())
return {'vertices': vertices, 'intersected': intersected}
def get_instance(self, root, instance):
mesh = self.read_mesh(root['mesh_dump'], instance)
dual_grid = self.read_dual_grid(root['dual_grid'], instance)
return {**mesh, **dual_grid}
@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
group_idx = [list(range(len(batch)))]
else:
group_idx = load_balanced_group_indices([b['vertices'].feats.shape[0] for b in batch], split_size)
packs = []
for group in group_idx:
sub_batch = [batch[i] for i in group]
pack = {}
keys = [k for k in sub_batch[0].keys()]
for k in keys:
if isinstance(sub_batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in sub_batch])
elif isinstance(sub_batch[0][k], sp.SparseTensor):
pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0)
elif isinstance(sub_batch[0][k], list):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]
packs.append(pack)
if split_size is None:
return packs[0]
return packs
\ No newline at end of file
import os
import json
from typing import *
import numpy as np
import torch
from ..representations import Voxel
from ..renderers import VoxelRenderer
from .components import StandardDatasetBase, ImageConditionedMixin
from .. import models
from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
class SparseStructureLatentVisMixin:
def __init__(
self,
*args,
pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json',
ss_dec_path: Optional[str] = None,
ss_dec_ckpt: Optional[str] = None,
**kwargs
):
super().__init__(*args, **kwargs)
self.ss_dec = None
self.pretrained_ss_dec = pretrained_ss_dec
self.ss_dec_path = ss_dec_path
self.ss_dec_ckpt = ss_dec_ckpt
def _loading_ss_dec(self):
if self.ss_dec is not None:
return
if self.ss_dec_path is not None:
cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_ss_dec)
self.ss_dec = decoder.cuda().eval()
def _delete_ss_dec(self):
del self.ss_dec
self.ss_dec = None
@torch.no_grad()
def decode_latent(self, z, batch_size=4):
self._loading_ss_dec()
ss = []
if self.normalization:
z = z * self.std.to(z.device) + self.mean.to(z.device)
for i in range(0, z.shape[0], batch_size):
ss.append(self.ss_dec(z[i:i+batch_size]))
ss = torch.cat(ss, dim=0)
self._delete_ss_dec()
return ss
@torch.no_grad()
def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
x_0 = self.decode_latent(x_0.cuda())
renderer = VoxelRenderer()
renderer.rendering_options.resolution = 512
renderer.rendering_options.ssaa = 4
# build camera
yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
yaw_offset = -16 / 180 * np.pi
yaw = [y + yaw_offset for y in yaw]
pitch = [20 / 180 * np.pi for _ in range(4)]
exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
images = []
# Build each representation
x_0 = x_0.cuda()
for i in range(x_0.shape[0]):
coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
resolution = x_0.shape[-1]
color = coords / resolution
rep = Voxel(
origin=[-0.5, -0.5, -0.5],
voxel_size=1/resolution,
coords=coords,
attrs=color,
layout={
'color': slice(0, 3),
}
)
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(rep, ext, intr, colors_overwrite=color)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images.append(image)
return torch.stack(images)
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
"""
Sparse structure latent dataset
Args:
roots (str): path to the dataset
min_aesthetic_score (float): minimum aesthetic score
normalization (dict): normalization stats
pretrained_ss_dec (str): name of the pretrained sparse structure decoder
ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
"""
def __init__(self,
roots: str,
*,
min_aesthetic_score: float = 5.0,
normalization: Optional[dict] = None,
pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
ss_dec_path: Optional[str] = None,
ss_dec_ckpt: Optional[str] = None,
):
self.min_aesthetic_score = min_aesthetic_score
self.normalization = normalization
self.value_range = (0, 1)
super().__init__(
roots,
pretrained_ss_dec=pretrained_ss_dec,
ss_dec_path=ss_dec_path,
ss_dec_ckpt=ss_dec_ckpt,
)
if self.normalization is not None:
self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata['ss_latent_encoded'] == True]
stats['With latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz'))
z = torch.tensor(latent['z']).float()
if self.normalization is not None:
z = (z - self.mean) / self.std
pack = {
'x_0': z,
}
return pack
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
"""
Image-conditioned sparse structure dataset
"""
pass
\ No newline at end of file
import os
import io
from typing import Union
import numpy as np
import pickle
import torch
from PIL import Image
import o_voxel
import utils3d
from .components import StandardDatasetBase
from ..modules import sparse as sp
from ..renderers import VoxelRenderer
from ..representations import Voxel
from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
from ..utils.data_utils import load_balanced_group_indices
def is_power_of_two(n: int) -> bool:
return n > 0 and (n & (n - 1)) == 0
def nearest_power_of_two(n: int) -> int:
if n < 1:
raise ValueError("n must be >= 1")
if is_power_of_two(n):
return n
lower = 2 ** (n.bit_length() - 1)
upper = 2 ** n.bit_length()
if n - lower < upper - n:
return lower
else:
return upper
class SparseVoxelPbrVisMixin:
@torch.no_grad()
def visualize_sample(self, x: Union[sp.SparseTensor, dict]):
x = x if isinstance(x, sp.SparseTensor) else x['x']
renderer = VoxelRenderer()
renderer.rendering_options.resolution = 512
renderer.rendering_options.ssaa = 4
# Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
exts = []
ints = []
for yaw, pitch in zip(yaws, pitch):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(30)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
exts.append(extrinsics)
ints.append(intrinsics)
images = {k: [] for k in self.layout}
# Build each representation
x = x.cuda()
for i in range(x.shape[0]):
rep = Voxel(
origin=[-0.5, -0.5, -0.5],
voxel_size=1/self.resolution,
coords=x[i].coords[:, 1:].contiguous(),
attrs=None,
layout={
'color': slice(0, 3),
}
)
for k in self.layout:
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
attr = x[i].feats[:, self.layout[k]].expand(-1, 3)
res = renderer.render(rep, ext, intr, colors_overwrite=attr)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images[k].append(image)
for k in self.layout:
images[k] = torch.stack(images[k])
return images
class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase):
"""
Sparse Voxel PBR dataset.
Args:
roots (str): path to the dataset
resolution (int): resolution of the voxel grid
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
"""
def __init__(
self,
roots,
resolution: int = 1024,
max_active_voxels: int = 1000000,
max_num_faces: int = None,
min_aesthetic_score: float = 5.0,
attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'],
with_mesh: bool = True,
):
self.resolution = resolution
self.min_aesthetic_score = min_aesthetic_score
self.max_active_voxels = max_active_voxels
self.max_num_faces = max_num_faces
self.with_mesh = with_mesh
self.value_range = (-1, 1)
self.channels = {
'base_color': 3,
'metallic': 1,
'roughness': 1,
'emissive': 3,
'alpha': 1,
}
self.layout = {}
start = 0
for attr in attrs:
self.layout[attr] = slice(start, start + self.channels[attr])
start += self.channels[attr]
super().__init__(roots)
self.loads = [self.metadata.loc[sha256, f'num_pbr_voxels'] for _, sha256 in self.instances]
def __str__(self):
lines = [
super().__str__(),
f' - Resolution: {self.resolution}',
f' - Attributes: {list(self.layout.keys())}',
]
return '\n'.join(lines)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata['pbr_voxelized'] == True]
stats['PBR Voxelized'] = len(metadata)
if self.min_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels]
stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata)
if self.max_num_faces is not None:
metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
return metadata, stats
@staticmethod
def _texture_from_dump(pack) -> Texture:
png_bytes = pack['image']
image = Image.open(io.BytesIO(png_bytes))
if image.width != image.height or not is_power_of_two(image.width):
size = nearest_power_of_two(max(image.width, image.height))
image = image.resize((size, size), Image.LANCZOS)
texture = torch.tensor(np.array(image) / 255.0, dtype=torch.float32).reshape(image.height, image.width, -1)
filter_mode = {
'Linear': TextureFilterMode.LINEAR,
'Closest': TextureFilterMode.CLOSEST,
'Cubic': TextureFilterMode.LINEAR,
'Smart': TextureFilterMode.LINEAR,
}[pack['interpolation']]
wrap_mode = {
'REPEAT': TextureWrapMode.REPEAT,
'EXTEND': TextureWrapMode.CLAMP_TO_EDGE,
'CLIP': TextureWrapMode.CLAMP_TO_EDGE,
'MIRROR': TextureWrapMode.MIRRORED_REPEAT,
}[pack['extension']]
return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode)
def read_mesh_with_texture(self, root, instance):
with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
dump = pickle.load(f)
# Fix dump alpha map
for mat in dump['materials']:
if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE':
mat['alphaMode'] = 'BLEND'
# process material
materials = []
for mat in dump['materials']:
materials.append(PbrMaterial(
base_color_texture=self._texture_from_dump(mat['baseColorTexture']) if mat['baseColorTexture'] is not None else None,
base_color_factor=mat['baseColorFactor'],
metallic_texture=self._texture_from_dump(mat['metallicTexture']) if mat['metallicTexture'] is not None else None,
metallic_factor=mat['metallicFactor'],
roughness_texture=self._texture_from_dump(mat['roughnessTexture']) if mat['roughnessTexture'] is not None else None,
roughness_factor=mat['roughnessFactor'],
alpha_texture=self._texture_from_dump(mat['alphaTexture']) if mat['alphaTexture'] is not None else None,
alpha_factor=mat['alphaFactor'],
alpha_mode={
'OPAQUE': AlphaMode.OPAQUE,
'MASK': AlphaMode.MASK,
'BLEND': AlphaMode.BLEND,
}[mat['alphaMode']],
alpha_cutoff=mat['alphaCutoff'],
))
materials.append(PbrMaterial(
base_color_factor=[0.8, 0.8, 0.8],
alpha_factor=1.0,
metallic_factor=0.0,
roughness_factor=0.5,
alpha_mode=AlphaMode.OPAQUE,
alpha_cutoff=0.5,
)) # append default material
# process mesh
start = 0
vertices = []
faces = []
material_ids = []
uv_coords = []
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
vertices.append(obj['vertices'])
faces.append(obj['faces'] + start)
obj['mat_ids'][obj['mat_ids'] == -1] = len(materials) - 1
material_ids.append(obj['mat_ids'])
uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32))
start += len(obj['vertices'])
vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long()
uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float()
# Normalize vertices
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
vertices = (vertices - center) * scale
assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
return {'mesh': [MeshWithPbrMaterial(
vertices=vertices,
faces=faces,
material_ids=material_ids,
uv_coords=uv_coords,
materials=materials,
)]}
def read_pbr_voxel(self, root, instance):
coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1
x = sp.SparseTensor(
feats.float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
return {'x': x}
def get_instance(self, root, instance):
if self.with_mesh:
mesh = self.read_mesh_with_texture(root['pbr_dump'], instance)
pbr_voxel = self.read_pbr_voxel(root['pbr_voxel'], instance)
return {**mesh, **pbr_voxel}
else:
return self.read_pbr_voxel(root['pbr_voxel'], instance)
@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
group_idx = [list(range(len(batch)))]
else:
group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size)
packs = []
for group in group_idx:
sub_batch = [batch[i] for i in group]
pack = {}
keys = [k for k in sub_batch[0].keys()]
for k in keys:
if isinstance(sub_batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in sub_batch])
elif isinstance(sub_batch[0][k], sp.SparseTensor):
pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0)
elif isinstance(sub_batch[0][k], list):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]
packs.append(pack)
if split_size is None:
return packs[0]
return packs
import json
import os
from typing import *
import numpy as np
import torch
import utils3d.torch
from .components import StandardDatasetBase, ImageConditionedMixin
from ..modules.sparse.basic import SparseTensor
from .. import models
from ..utils.render_utils import get_renderer
from ..utils.data_utils import load_balanced_group_indices
class SLatVisMixin:
def __init__(
self,
*args,
pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
**kwargs
):
super().__init__(*args, **kwargs)
self.slat_dec = None
self.pretrained_slat_dec = pretrained_slat_dec
self.slat_dec_path = slat_dec_path
self.slat_dec_ckpt = slat_dec_ckpt
def _loading_slat_dec(self):
if self.slat_dec is not None:
return
if self.slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_slat_dec)
self.slat_dec = decoder.cuda().eval()
def _delete_slat_dec(self):
del self.slat_dec
self.slat_dec = None
@torch.no_grad()
def decode_latent(self, z, batch_size=4):
self._loading_slat_dec()
reps = []
if self.normalization is not None:
z = z * self.std.to(z.device) + self.mean.to(z.device)
for i in range(0, z.shape[0], batch_size):
reps.append(self.slat_dec(z[i:i+batch_size]))
reps = sum(reps, [])
self._delete_slat_dec()
return reps
@torch.no_grad()
def visualize_sample(self, x_0: Union[SparseTensor, dict]):
x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
reps = self.decode_latent(x_0.cuda())
# Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
exts = []
ints = []
for yaw, pitch in zip(yaws, pitch):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(40)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
exts.append(extrinsics)
ints.append(intrinsics)
renderer = get_renderer(reps[0])
images = []
for representation in reps:
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(representation, ext, intr)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images.append(image)
images = torch.stack(images)
return images
class SLat(SLatVisMixin, StandardDatasetBase):
"""
structured latent V2 dataset
Args:
roots (str): path to the dataset
min_aesthetic_score (float): minimum aesthetic score
max_tokens (int): maximum number of tokens
latent_key (str): key of the latent to be used
normalization (dict): normalization stats
pretrained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def __init__(self,
roots: str,
*,
min_aesthetic_score: float = 5.0,
max_tokens: int = 32768,
latent_key: str = 'shape_latent',
normalization: Optional[dict] = None,
pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
):
self.normalization = normalization
self.min_aesthetic_score = min_aesthetic_score
self.max_tokens = max_tokens
self.latent_key = latent_key
self.value_range = (0, 1)
super().__init__(
roots,
pretrained_slat_dec=pretrained_slat_dec,
slat_dec_path=slat_dec_path,
slat_dec_ckpt=slat_dec_ckpt,
)
self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256 in self.instances]
if self.normalization is not None:
self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True]
stats['With latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata[f'{self.latent_key}_tokens'] <= self.max_tokens]
stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
data = np.load(os.path.join(root[self.latent_key], f'{instance}.npz'))
coords = torch.tensor(data['coords']).int()
feats = torch.tensor(data['feats']).float()
if self.normalization is not None:
feats = (feats - self.mean) / self.std
return {
'coords': coords,
'feats': feats,
}
@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
group_idx = [list(range(len(batch)))]
else:
group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
packs = []
for group in group_idx:
sub_batch = [batch[i] for i in group]
pack = {}
coords = []
feats = []
layout = []
start = 0
for i, b in enumerate(sub_batch):
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
feats.append(b['feats'])
layout.append(slice(start, start + b['coords'].shape[0]))
start += b['coords'].shape[0]
coords = torch.cat(coords)
feats = torch.cat(feats)
pack['x_0'] = SparseTensor(
coords=coords,
feats=feats,
)
pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
pack['x_0'].register_spatial_cache('layout', layout)
# collate other data
keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
for k in keys:
if isinstance(sub_batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in sub_batch])
elif isinstance(sub_batch[0][k], list):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]
packs.append(pack)
if split_size is None:
return packs[0]
return packs
class ImageConditionedSLat(ImageConditionedMixin, SLat):
"""
Image conditioned structured latent dataset
"""
pass
import os
import json
from typing import *
import numpy as np
import torch
from .. import models
from .components import ImageConditionedMixin
from ..modules.sparse import SparseTensor
from .structured_latent import SLatVisMixin, SLat
from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics
class SLatShapeVisMixin(SLatVisMixin):
def _loading_slat_dec(self):
if self.slat_dec is not None:
return
if self.slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_slat_dec)
decoder.set_resolution(self.resolution)
self.slat_dec = decoder.cuda().eval()
@torch.no_grad()
def visualize_sample(self, x_0: Union[SparseTensor, dict]):
x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
reps = self.decode_latent(x_0.cuda())
# build camera
yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
yaw_offset = -16 / 180 * np.pi
yaw = [y + yaw_offset for y in yaw]
pitch = [20 / 180 * np.pi for _ in range(4)]
exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
# render
renderer = get_renderer(reps[0])
images = []
for representation in reps:
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(representation, ext, intr)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal']
images.append(image)
images = torch.stack(images)
return images
class SLatShape(SLatShapeVisMixin, SLat):
"""
structured latent for shape generation
Args:
roots (str): path to the dataset
resolution (int): resolution of the shape
min_aesthetic_score (float): minimum aesthetic score
max_tokens (int): maximum number of tokens
latent_key (str): key of the latent to be used
normalization (dict): normalization stats
pretrained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def __init__(self,
roots: str,
*,
resolution: int,
min_aesthetic_score: float = 5.0,
max_tokens: int = 32768,
normalization: Optional[dict] = None,
pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
):
super().__init__(
roots,
min_aesthetic_score=min_aesthetic_score,
max_tokens=max_tokens,
latent_key='shape_latent',
normalization=normalization,
pretrained_slat_dec=pretrained_slat_dec,
slat_dec_path=slat_dec_path,
slat_dec_ckpt=slat_dec_ckpt,
)
self.resolution = resolution
class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape):
"""
Image conditioned structured latent for shape generation
"""
pass
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import json
from typing import *
import numpy as np
import torch
import cv2
from .. import models
from .components import StandardDatasetBase, ImageConditionedMixin
from ..modules.sparse import SparseTensor, sparse_cat
from ..representations import MeshWithVoxel
from ..renderers import PbrMeshRenderer, EnvMap
from ..utils.data_utils import load_balanced_group_indices
from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
class SLatPbrVisMixin:
def __init__(
self,
*args,
pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16',
pbr_slat_dec_path: Optional[str] = None,
pbr_slat_dec_ckpt: Optional[str] = None,
pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
shape_slat_dec_path: Optional[str] = None,
shape_slat_dec_ckpt: Optional[str] = None,
**kwargs
):
super().__init__(*args, **kwargs)
self.pbr_slat_dec = None
self.pretrained_pbr_slat_dec = pretrained_pbr_slat_dec
self.pbr_slat_dec_path = pbr_slat_dec_path
self.pbr_slat_dec_ckpt = pbr_slat_dec_ckpt
self.shape_slat_dec = None
self.pretrained_shape_slat_dec = pretrained_shape_slat_dec
self.shape_slat_dec_path = shape_slat_dec_path
self.shape_slat_dec_ckpt = shape_slat_dec_ckpt
def _loading_slat_dec(self):
if self.pbr_slat_dec is not None and self.shape_slat_dec is not None:
return
if self.pbr_slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.pbr_slat_dec_path, 'config.json'), 'r'))
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.pbr_slat_dec_path, 'ckpts', f'decoder_{self.pbr_slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_pbr_slat_dec)
self.pbr_slat_dec = decoder.cuda().eval()
if self.shape_slat_dec_path is not None:
cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r'))
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_shape_slat_dec)
decoder.set_resolution(self.resolution)
self.shape_slat_dec = decoder.cuda().eval()
def _delete_slat_dec(self):
del self.pbr_slat_dec
self.pbr_slat_dec = None
del self.shape_slat_dec
self.shape_slat_dec = None
@torch.no_grad()
def decode_latent(self, z, shape_z, batch_size=4):
self._loading_slat_dec()
reps = []
if self.shape_slat_normalization is not None:
shape_z = shape_z * self.shape_slat_std.to(z.device) + self.shape_slat_mean.to(z.device)
if self.pbr_slat_normalization is not None:
z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device)
for i in range(0, z.shape[0], batch_size):
mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True)
vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5
reps.extend([
MeshWithVoxel(
m.vertices, m.faces,
origin = [-0.5, -0.5, -0.5],
voxel_size = 1 / self.resolution,
coords = v.coords[:, 1:],
attrs = v.feats,
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
layout = self.layout,
)
for m, v in zip(mesh, vox)
])
self._delete_slat_dec()
return reps
@torch.no_grad()
def visualize_sample(self, sample: dict):
shape_z = sample['concat_cond'].cuda()
z = sample['x_0'].cuda()
reps = self.decode_latent(z, shape_z)
# build camera
yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
yaw_offset = -16 / 180 * np.pi
yaw = [y + yaw_offset for y in yaw]
pitch = [20 / 180 * np.pi for _ in range(4)]
exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
# render
renderer = PbrMeshRenderer()
renderer.rendering_options.resolution = 512
renderer.rendering_options.near = 1
renderer.rendering_options.far = 100
renderer.rendering_options.ssaa = 2
renderer.rendering_options.peel_layers = 8
envmap = EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
))
images = {}
for representation in reps:
image = {}
tile = [2, 2]
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(representation, ext, intr, envmap=envmap)
for k, v in res.items():
if k not in images:
images[k] = []
if k not in image:
image[k] = torch.zeros(3, 1024, 1024).cuda()
image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v
for k in images.keys():
images[k].append(image[k])
for k in images.keys():
images[k] = torch.stack(images[k], dim=0)
return images
class SLatPbr(SLatPbrVisMixin, StandardDatasetBase):
"""
structured latent for sparse voxel pbr dataset
Args:
roots (str): path to the dataset
latent_key (str): key of the latent to be used
min_aesthetic_score (float): minimum aesthetic score
normalization (dict): normalization stats
resolution (int): resolution of decoded sparse voxel
attrs (list): attributes to be decoded
pretained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def __init__(self,
roots: str,
*,
resolution: int,
min_aesthetic_score: float = 5.0,
max_tokens: int = 32768,
full_pbr: bool = False,
pbr_slat_normalization: Optional[dict] = None,
shape_slat_normalization: Optional[dict] = None,
attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'],
pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16',
pbr_slat_dec_path: Optional[str] = None,
pbr_slat_dec_ckpt: Optional[str] = None,
pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
shape_slat_dec_path: Optional[str] = None,
shape_slat_dec_ckpt: Optional[str] = None,
**kwargs
):
self.resolution = resolution
self.pbr_slat_normalization = pbr_slat_normalization
self.shape_slat_normalization = shape_slat_normalization
self.min_aesthetic_score = min_aesthetic_score
self.max_tokens = max_tokens
self.full_pbr = full_pbr
self.value_range = (0, 1)
super().__init__(
roots,
pretrained_pbr_slat_dec=pretrained_pbr_slat_dec,
pbr_slat_dec_path=pbr_slat_dec_path,
pbr_slat_dec_ckpt=pbr_slat_dec_ckpt,
pretrained_shape_slat_dec=pretrained_shape_slat_dec,
shape_slat_dec_path=shape_slat_dec_path,
shape_slat_dec_ckpt=shape_slat_dec_ckpt,
**kwargs
)
self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256 in self.instances]
if self.pbr_slat_normalization is not None:
self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1)
self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1)
if self.shape_slat_normalization is not None:
self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1)
self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1)
self.attrs = attrs
self.channels = {
'base_color': 3,
'metallic': 1,
'roughness': 1,
'emissive': 3,
'alpha': 1,
}
self.layout = {}
start = 0
for attr in attrs:
self.layout[attr] = slice(start, start + self.channels[attr])
start += self.channels[attr]
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata['pbr_latent_encoded'] == True]
stats['With PBR latent'] = len(metadata)
metadata = metadata[metadata['shape_latent_encoded'] == True]
stats['With shape latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens]
stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
if self.full_pbr:
metadata = metadata[metadata['num_basecolor_tex'] > 0]
metadata = metadata[metadata['num_metallic_tex'] > 0]
metadata = metadata[metadata['num_roughness_tex'] > 0]
stats['Full PBR'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
# PBR latent
data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz'))
coords = torch.tensor(data['coords']).int()
coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1)
feats = torch.tensor(data['feats']).float()
if self.pbr_slat_normalization is not None:
feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std
pbr_z = SparseTensor(feats, coords)
# Shape latent
data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz'))
coords = torch.tensor(data['coords']).int()
coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1)
feats = torch.tensor(data['feats']).float()
if self.shape_slat_normalization is not None:
feats = (feats - self.shape_slat_mean) / self.shape_slat_std
shape_z = SparseTensor(feats, coords)
assert torch.equal(shape_z.coords, pbr_z.coords), \
f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}"
return {
'x_0': pbr_z,
'concat_cond': shape_z,
}
@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
group_idx = [list(range(len(batch)))]
else:
group_idx = load_balanced_group_indices([b['x_0'].feats.shape[0] for b in batch], split_size)
packs = []
for group in group_idx:
sub_batch = [batch[i] for i in group]
pack = {}
keys = [k for k in sub_batch[0].keys()]
for k in keys:
if isinstance(sub_batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in sub_batch])
elif isinstance(sub_batch[0][k], SparseTensor):
pack[k] = sparse_cat([b[k] for b in sub_batch], dim=0)
elif isinstance(sub_batch[0][k], list):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]
packs.append(pack)
if split_size is None:
return packs[0]
return packs
class ImageConditionedSLatPbr(ImageConditionedMixin, SLatPbr):
"""
Image conditioned structured latent dataset
"""
pass
import importlib
__attributes = {
# Sparse Structure
'SparseStructureEncoder': 'sparse_structure_vae',
'SparseStructureDecoder': 'sparse_structure_vae',
'SparseStructureFlowModel': 'sparse_structure_flow',
# SLat Generation
'SLatFlowModel': 'structured_latent_flow',
'ElasticSLatFlowModel': 'structured_latent_flow',
# SC-VAEs
'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae',
'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae',
'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae',
'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae'
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
def from_pretrained(path: str, **kwargs):
"""
Load a model from a pretrained checkpoint.
Args:
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
**kwargs: Additional arguments for the model constructor.
"""
import os
import json
from safetensors.torch import load_file
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
if is_local:
config_file = f"{path}.json"
model_file = f"{path}.safetensors"
else:
from huggingface_hub import hf_hub_download
path_parts = path.split('/')
repo_id = f'{path_parts[0]}/{path_parts[1]}'
model_name = '/'.join(path_parts[2:])
config_file = hf_hub_download(repo_id, f"{model_name}.json")
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
with open(config_file, 'r') as f:
config = json.load(f)
model = __getattr__(config['name'])(**config['args'], **kwargs)
model.load_state_dict(load_file(model_file), strict=False)
return model
# For Pylance
if __name__ == '__main__':
from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
from .sparse_structure_flow import SparseStructureFlowModel
from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel
from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder
from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ...utils.pipeline_logger import get_logger
from .sparse_unet_vae import (
SparseResBlock3d,
SparseConvNeXtBlock3d,
SparseResBlockDownsample3d,
SparseResBlockUpsample3d,
SparseResBlockS2C3d,
SparseResBlockC2S3d,
)
from .sparse_unet_vae import (
SparseUnetVaeEncoder,
SparseUnetVaeDecoder,
)
from ...representations import Mesh
from o_voxel.convert import flexible_dual_grid_to_mesh
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
def __init__(
self,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
down_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
):
super().__init__(
6,
model_channels,
latent_channels,
num_blocks,
block_type,
down_block_type,
block_args,
use_fp16,
)
def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
x = vertices.replace(torch.cat([
vertices.feats - 0.5,
intersected.feats.float() - 0.5,
], dim=1))
return super().forward(x, sample_posterior, return_raw)
class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
def __init__(
self,
resolution: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
up_block_type: List[str],
block_args: List[Dict[str, Any]],
voxel_margin: float = 0.5,
use_fp16: bool = False,
):
self.resolution = resolution
self.voxel_margin = voxel_margin
super().__init__(
7,
model_channels,
latent_channels,
num_blocks,
block_type,
up_block_type,
block_args,
use_fp16,
)
def set_resolution(self, resolution: int) -> None:
self.resolution = resolution
def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs):
decoded = super().forward(x, **kwargs)
if self.training:
h, subs_gt, subs = decoded
vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
intersected_logits = h.replace(h.feats[..., 3:6])
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
mesh = [Mesh(*flexible_dual_grid_to_mesh(
v.coords[:, 1:], v.feats, i.feats, q.feats,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
grid_size=self.resolution,
train=True
)) for v, i, q in zip(vertices, gt_intersected, quad_lerp)]
return mesh, vertices, intersected_logits, subs_gt, subs
else:
out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
h = out_list[0]
get_logger().debug(f"post-forward dtype={h.feats.dtype} has_nan={torch.isnan(h.feats).any()}")
get_logger().debug(f"DEBUG 1: VAE output h.feats has NaNs: {torch.isnan(h.feats).any().item()}")
vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
intersected = h.replace(h.feats[..., 3:6] > 0)
get_logger().debug(f"DEBUG INTERSECTED: total={intersected.feats.shape[0]}, "
f"true={intersected.feats.any(dim=-1).sum().item()}, "
f"ratio={intersected.feats.any(dim=-1).float().mean():.3f}")
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
mesh = [Mesh(*flexible_dual_grid_to_mesh(
v.coords[:, 1:], v.feats, i.feats, q.feats,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
grid_size=self.resolution,
train=False
)) for v, i, q in zip(vertices, intersected, quad_lerp)]
get_logger().debug(f"DEBUG 2: o_voxel mesh[0] vertices has NaNs: {torch.isnan(mesh[0].vertices).any().item()}")
out_list[0] = mesh
return out_list[0] if len(out_list) == 1 else tuple(out_list)
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...modules.utils import convert_module_to_f16, convert_module_to_bf16, convert_module_to_f32, zero_module
from ...modules import sparse as sp
from ...modules.sparse.linear import rocm_safe_linear, ROCM_SAFE_CHUNK
from ...modules.norm import LayerNorm32
from ...utils.pipeline_logger import get_logger
class SparseResBlock3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
downsample: bool = False,
upsample: bool = False,
resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.downsample = downsample
self.upsample = upsample
self.resample_mode = resample_mode
self.use_checkpoint = use_checkpoint
assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
if resample_mode == 'nearest':
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
elif resample_mode =='spatial2channel' and not self.downsample:
self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
elif resample_mode =='spatial2channel' and self.downsample:
self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
if resample_mode == 'nearest':
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
elif resample_mode =='spatial2channel' and self.downsample:
self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
elif resample_mode =='spatial2channel' and not self.downsample:
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
self.updown = None
if self.downsample:
if resample_mode == 'nearest':
self.updown = sp.SparseDownsample(2)
elif resample_mode =='spatial2channel':
self.updown = sp.SparseSpatial2Channel(2)
elif self.upsample:
self.to_subdiv = sp.SparseLinear(channels, 8)
if resample_mode == 'nearest':
self.updown = sp.SparseUpsample(2)
elif resample_mode =='spatial2channel':
self.updown = sp.SparseChannel2Spatial(2)
def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
if self.downsample:
x = self.updown(x)
elif self.upsample:
x = self.updown(x, subdiv.replace(subdiv.feats > 0))
return x
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
subdiv = None
if self.upsample:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
if self.resample_mode == 'spatial2channel':
h = self.conv1(h)
h = self._updown(h, subdiv)
x = self._updown(x, subdiv)
if self.resample_mode == 'nearest':
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.upsample:
return h, subdiv
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockDownsample3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
self.updown = sp.SparseDownsample(2)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.updown(h)
x = self.updown(x)
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockUpsample3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.pred_subdiv = pred_subdiv
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
if self.pred_subdiv:
self.to_subdiv = sp.SparseLinear(channels, 8)
self.updown = sp.SparseUpsample(2)
def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.pred_subdiv:
return h, subdiv
else:
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockS2C3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
self.updown = sp.SparseSpatial2Channel(2)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
h = self.updown(h)
x = self.updown(x)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockC2S3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.pred_subdiv = pred_subdiv
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
if pred_subdiv:
self.to_subdiv = sp.SparseLinear(channels, 8)
self.updown = sp.SparseChannel2Spatial(2)
def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
# ROCm: cast to fp32 before threshold - bf16 trained weights produce shifted logits\n
subdiv_binarized = subdiv.replace(subdiv.feats.float() > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.pred_subdiv:
return h, subdiv
else:
return h
def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
else:
return self._forward(x, subdiv)
class SparseConvNeXtBlock3d(nn.Module):
def __init__(
self,
channels: int,
mlp_ratio: float = 4.0,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.use_checkpoint = use_checkpoint
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.conv = sp.SparseConv3d(channels, channels, 3)
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.SiLU(),
zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = self.conv(x)
h = h.replace(self.norm(h.feats))
# ROCm GFX1201 bug workaround: chunk MLP (two nn.Linear layers inside) for large N
# The MLP is row-independent so chunking is exact, not an approximation
feats = h.feats
N = feats.shape[0]
if N <= ROCM_SAFE_CHUNK:
h = h.replace(self.mlp(feats))
else:
out = torch.empty_like(feats)
for s in range(0, N, ROCM_SAFE_CHUNK):
e = min(s + ROCM_SAFE_CHUNK, N)
out[s:e] = self.mlp(feats[s:e])
h = h.replace(out)
return h + x
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseUnetVaeEncoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
in_channels: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
down_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)
self.blocks = nn.ModuleList([])
for i in range(len(num_blocks)):
self.blocks.append(nn.ModuleList([]))
for j in range(num_blocks[i]):
self.blocks[-1].append(
globals()[block_type[i]](
model_channels[i],
**block_args[i],
)
)
if i < len(num_blocks) - 1:
self.blocks[-1].append(
globals()[down_block_type[i]](
model_channels[i],
model_channels[i+1],
**block_args[i],
)
)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16 (actually bfloat16 for ROCm stability).
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
h = self.input_layer(x)
h = h.type(self.dtype)
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
h = block(h)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.to_latent(h)
# Sample from the posterior distribution
mean, logvar = h.feats.chunk(2, dim=-1)
if sample_posterior:
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
else:
z = mean
z = h.replace(z)
if return_raw:
return z, mean, logvar
else:
return z
class SparseUnetVaeDecoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
out_channels: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
up_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.out_channels = out_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.use_fp16 = use_fp16
self.pred_subdiv = pred_subdiv
self.dtype = torch.float16 if use_fp16 else torch.float32
self.low_vram = False
self.output_layer = sp.SparseLinear(model_channels[-1], out_channels)
self.from_latent = sp.SparseLinear(latent_channels, model_channels[0])
self.blocks = nn.ModuleList([])
for i in range(len(num_blocks)):
self.blocks.append(nn.ModuleList([]))
for j in range(num_blocks[i]):
self.blocks[-1].append(
globals()[block_type[i]](
model_channels[i],
**block_args[i],
)
)
if i < len(num_blocks) - 1:
self.blocks[-1].append(
globals()[up_block_type[i]](
model_channels[i],
model_channels[i+1],
pred_subdiv=pred_subdiv,
**block_args[i],
)
)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16 (actually bfloat16 for ROCm stability).
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor:
assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs"
assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs"
h = self.from_latent(x)
get_logger().debug(f"DECODER from_latent: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f} dtype={h.feats.dtype}")
h = h.type(self.dtype)
get_logger().debug(f"DECODER after dtype cast: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f} dtype={h.feats.dtype}")
subs_gt = []
subs = []
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
if i < len(self.blocks) - 1 and j == len(res) - 1:
if self.pred_subdiv:
if self.training:
subs_gt.append(h.get_spatial_cache('subdivision'))
h, sub = block(h)
subs.append(sub)
else:
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
else:
h = block(h)
if not torch.isfinite(h.feats).all():
print(f"FATAL: NaN/Inf at decoder block i={i} j={j} type={type(block).__name__} max={h.feats.float().abs().max().item():.4f}")
import sys; sys.exit(1)
h = h.type(x.dtype)
get_logger().debug(f"DECODER post-blocks cast: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f} dtype={h.feats.dtype}")
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
get_logger().debug(f"DECODER post-layernorm: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f}")
get_logger().debug(f"DECODER output_layer input: shape={h.feats.shape} stride={h.feats.stride()} contiguous={h.feats.is_contiguous()}")
get_logger().debug(f"DECODER output_layer weight: shape={self.output_layer.weight.shape} dtype={self.output_layer.weight.dtype}")
get_logger().debug(f"DECODER pre-output_layer: feats shape={h.feats.shape} contiguous={h.feats.is_contiguous()} weight shape={self.output_layer.weight.shape} weight dtype={self.output_layer.weight.dtype}")
# ROCm workaround: ensure contiguous before F.linear
h = h.replace(h.feats.contiguous())
h = self.output_layer(h)
get_logger().debug(f"DECODER post-output_layer: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f} dtype={h.feats.dtype}")
get_logger().debug(f"DEBUG OUTPUT_LAYER: dtype={h.feats.dtype} has_nan={torch.isnan(h.feats).any().item()} max_abs={h.feats.abs().max().item() if h.feats.numel() > 0 else 0}")
if self.training and self.pred_subdiv:
return h, subs_gt, subs
else:
if return_subs:
return h, subs
else:
return h
# REPLACE WITH:
def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor:
assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling"
h = self.from_latent(x)
get_logger().debug(f"UPSAMPLE from_latent: dtype={h.feats.dtype} nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f}")
h = h.type(self.dtype)
get_logger().debug(f"UPSAMPLE after type cast to {self.dtype}: nan={torch.isnan(h.feats).any().item()} inf={torch.isinf(h.feats).any().item()} max={h.feats.float().abs().max().item():.4f}")
for i, res in enumerate(self.blocks):
if i == upsample_times:
return h.coords
for j, block in enumerate(res):
if i < len(self.blocks) - 1 and j == len(res) - 1:
h, sub = block(h)
else:
h = block(h)
if torch.isnan(h.feats).any() or torch.isinf(h.feats).any():
print(f"UPSAMPLE NaN/Inf at block i={i} j={j} type={type(block).__name__} max={h.feats.float().abs().max().item():.4f}")
break
else:
continue
break
def dump_debug(self, tag: str, tensor) -> None:
import os
os.makedirs('/tmp/trellis_debug', exist_ok=True)
path = f'/tmp/trellis_debug/{tag}.pt'
torch.save({'feats': tensor.feats.float().cpu(), 'coords': tensor.coords.cpu()}, path)
print(f"DUMPED {tag}: feats dtype={tensor.feats.dtype} shape={tensor.feats.shape} "
f"has_nan={torch.isnan(tensor.feats).any().item()} "
f"has_inf={torch.isinf(tensor.feats).any().item()} "
f"max={tensor.feats.float().abs().max().item():.4f}")
from contextlib import contextmanager
from typing import *
import math
from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin
class SparseTransformerElasticMixin(ElasticModuleMixin):
def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
return x.feats.shape[0]
@contextmanager
def with_mem_ratio(self, mem_ratio=1.0):
if mem_ratio == 1.0:
yield 1.0
return
num_blocks = len(self.blocks)
num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
for i in range(num_blocks):
self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
yield exact_mem_ratio
for i in range(num_blocks):
self.blocks[i].use_checkpoint = False
from typing import *
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
from ..modules.attention import RotaryPositionEmbedder
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim: the dimension of the output.
max_period: controls the minimum frequency of the embeddings.
Returns:
an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class SparseStructureFlowModel(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "ape",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
dtype: str = 'float32',
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
**kwargs
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = str_to_dtype(dtype)
self.t_embedder = TimestepEmbedder(model_channels)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
)
if pe_mode == "ape":
pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
pos_emb = pos_embedder(coords)
self.register_buffer("pos_emb", pos_emb)
elif pe_mode == "rope":
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
rope_phases = pos_embedder(coords)
self.register_buffer("rope_phases", rope_phases)
if pe_mode != "rope":
self.rope_phases = None
self.input_layer = nn.Linear(in_channels, model_channels)
self.blocks = nn.ModuleList([
ModulatedTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
)
for _ in range(num_blocks)
])
self.out_layer = nn.Linear(model_channels, out_channels)
self.initialize_weights()
self.convert_to(self.dtype)
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to(self, dtype: torch.dtype) -> None:
"""
Convert the torso of the model to the specified dtype.
"""
self.dtype = dtype
self.blocks.apply(partial(convert_module_to, dtype=dtype))
def initialize_weights(self) -> None:
if self.initialization == 'vanilla':
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
if self.share_mod:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
else:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
elif self.initialization == 'scaled':
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Scaled init for to_out and ffn2
def _scaled_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
for block in self.blocks:
block.self_attn.to_out.apply(_scaled_init)
block.cross_attn.to_out.apply(_scaled_init)
block.mlp.mlp[2].apply(_scaled_init)
# Initialize input layer to make the initial representation have variance 1
nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
nn.init.zeros_(self.input_layer.bias)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
if self.share_mod:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
else:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
h = self.input_layer(h)
if self.pe_mode == "ape":
h = h + self.pos_emb[None]
t_emb = self.t_embedder(t)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
h = manual_cast(h, self.dtype)
cond = manual_cast(cond, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond, self.rope_phases)
h = manual_cast(h, x.dtype)
h = F.layer_norm(h, h.shape[-1:])
h = self.out_layer(h)
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
return h
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment