Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from typing import *
from abc import ABC, abstractmethod
class Sampler(ABC):
"""
A base class for samplers.
"""
@abstractmethod
def sample(
self,
model,
**kwargs
):
"""
Sample from a model.
"""
pass
\ No newline at end of file
from typing import *
class ClassifierFreeGuidanceSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs):
if guidance_strength == 1:
return super()._inference_model(model, x_t, t, cond, **kwargs)
elif guidance_strength == 0:
return super()._inference_model(model, x_t, t, neg_cond, **kwargs)
else:
pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs)
pred_neg = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
# CFG rescale
if guidance_rescale > 0:
x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
x_0_cfg = self._pred_to_xstart(x_t, t, pred)
std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
pred = self._xstart_to_pred(x_t, t, x_0)
return pred
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
class FlowEulerSampler(Sampler):
"""
Generate samples from a flow-matching model using Euler sampling.
Args:
sigma_min: The minimum scale of noise in flow.
"""
def __init__(
self,
sigma_min: float,
):
self.sigma_min = sigma_min
def _eps_to_xstart(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
def _xstart_to_eps(self, x_t, t, x_0):
assert x_t.shape == x_0.shape
return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
def _v_to_xstart_eps(self, x_t, t, v):
assert x_t.shape == v.shape
eps = (1 - t) * v + x_t
x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
return x_0, eps
def _pred_to_xstart(self, x_t, t, pred):
return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred
def _xstart_to_pred(self, x_t, t, x_0):
return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
def _inference_model(self, model, x_t, t, cond=None, **kwargs):
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
return model(x_t, t, cond, **kwargs)
def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
return pred_x_0, pred_eps, pred_v
@torch.no_grad()
def sample_once(
self,
model,
x_t,
t: float,
t_prev: float,
cond: Optional[Any] = None,
**kwargs
):
"""
Sample x_{t-1} from the model using Euler method.
Args:
model: The model to sample from.
x_t: The [N x C x ...] tensor of noisy inputs at time t.
t: The current timestep.
t_prev: The previous timestep.
cond: conditional information.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'pred_x_prev': x_{t-1}.
- 'pred_x_0': a prediction of x_0.
"""
pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
pred_x_prev = x_t - (t - t_prev) * pred_v
return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
@torch.no_grad()
def sample(
self,
model,
noise,
cond: Optional[Any] = None,
steps: int = 50,
rescale_t: float = 1.0,
verbose: bool = True,
tqdm_desc: str = "Sampling",
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
verbose: If True, show a progress bar.
tqdm_desc: A customized tqdm desc.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
sample = noise
t_seq = np.linspace(1, 0, steps + 1)
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
t_seq = t_seq.tolist()
t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
for t, t_prev in tqdm(t_pairs, desc=tqdm_desc, disable=not verbose):
out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
sample = out.pred_x_prev
ret.pred_x_t.append(out.pred_x_prev)
ret.pred_x_0.append(out.pred_x_0)
ret.samples = sample
return ret
class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
guidance_strength: float = 3.0,
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
guidance_strength: The strength of classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs)
class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
guidance_strength: float = 3.0,
guidance_interval: Tuple[float, float] = (0.0, 1.0),
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
guidance_strength: The strength of classifier-free guidance.
guidance_interval: The interval for classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs)
from typing import *
class GuidanceIntervalSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance with interval.
"""
def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs):
if guidance_interval[0] <= t <= guidance_interval[1]:
return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs)
else:
return super()._inference_model(model, x_t, t, cond, guidance_strength=1, **kwargs)
from typing import *
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from .base import Pipeline
from . import samplers, rembg
from ..modules.sparse import SparseTensor
from ..modules import image_feature_extractor
from ..representations import Mesh, MeshWithVoxel
from ..utils.pipeline_logger import get_logger, log_sparse, log_mesh, log_tensor, section, elapsed
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class Trellis2ImageTo3DPipeline(Pipeline):
"""
Pipeline for inferring Trellis2 image-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
shape_slat_sampler (samplers.Sampler): The sampler for the structured latent.
tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler.
shape_slat_sampler_params (dict): The parameters for the structured latent sampler.
tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
shape_slat_normalization (dict): The normalization parameters for the structured latent.
tex_slat_normalization (dict): The normalization parameters for the texture latent.
image_cond_model (Callable): The image conditioning model.
rembg_model (Callable): The model for removing background.
low_vram (bool): Whether to use low-VRAM mode.
"""
model_names_to_load = [
'sparse_structure_flow_model',
'sparse_structure_decoder',
'shape_slat_flow_model_512',
'shape_slat_flow_model_1024',
'shape_slat_decoder',
'tex_slat_flow_model_512',
'tex_slat_flow_model_1024',
'tex_slat_decoder',
]
def __init__(
self,
models: dict[str, nn.Module] = None,
sparse_structure_sampler: samplers.Sampler = None,
shape_slat_sampler: samplers.Sampler = None,
tex_slat_sampler: samplers.Sampler = None,
sparse_structure_sampler_params: dict = None,
shape_slat_sampler_params: dict = None,
tex_slat_sampler_params: dict = None,
shape_slat_normalization: dict = None,
tex_slat_normalization: dict = None,
image_cond_model: Callable = None,
rembg_model: Callable = None,
low_vram: bool = True,
default_pipeline_type: str = '1024_cascade',
):
if models is None:
return
super().__init__(models)
self.sparse_structure_sampler = sparse_structure_sampler
self.shape_slat_sampler = shape_slat_sampler
self.tex_slat_sampler = tex_slat_sampler
self.sparse_structure_sampler_params = sparse_structure_sampler_params
self.shape_slat_sampler_params = shape_slat_sampler_params
self.tex_slat_sampler_params = tex_slat_sampler_params
self.shape_slat_normalization = shape_slat_normalization
self.tex_slat_normalization = tex_slat_normalization
self.image_cond_model = image_cond_model
self.rembg_model = rembg_model
self.low_vram = low_vram
self.default_pipeline_type = default_pipeline_type
self.pbr_attr_layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
self._device = 'cpu'
@classmethod
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super().from_pretrained(path, config_file)
args = pipeline._pretrained_args
pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
pipeline.shape_slat_normalization = args['shape_slat_normalization']
pipeline.tex_slat_normalization = args['tex_slat_normalization']
pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
pipeline.low_vram = args.get('low_vram', True)
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
pipeline.pbr_attr_layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
pipeline._device = 'cpu'
return pipeline
def to(self, device: torch.device) -> None:
self._device = device
if not self.low_vram:
super().to(device)
self.image_cond_model.to(device)
if self.rembg_model is not None:
self.rembg_model.to(device)
def preprocess_image(self, input: Image.Image) -> Image.Image:
"""
Preprocess the input image.
"""
# if has alpha channel, use it directly; otherwise, remove background
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, 3]
if not np.all(alpha == 255):
has_alpha = True
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
if has_alpha:
output = input
else:
input = input.convert('RGB')
if self.low_vram:
self.rembg_model.to(self.device)
output = self.rembg_model(input)
if self.low_vram:
self.rembg_model.cpu()
output_np = np.array(output)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = output.crop(bbox) # type: ignore
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
output = Image.fromarray((output * 255).astype(np.uint8))
return output
def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
self.image_cond_model.image_size = resolution
if self.low_vram:
self.image_cond_model.to(self.device)
cond = self.image_cond_model(image)
if self.low_vram:
self.image_cond_model.cpu()
if not include_neg_cond:
return {'cond': cond}
neg_cond = torch.zeros_like(cond)
return {
'cond': cond,
'neg_cond': neg_cond,
}
def sample_sparse_structure(
self,
cond: dict,
resolution: int,
num_samples: int = 1,
sampler_params: dict = {},
) -> torch.Tensor:
"""
Sample sparse structures with the given conditioning.
Args:
cond (dict): The conditioning information.
resolution (int): The resolution of the sparse structure.
num_samples (int): The number of samples to generate.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample sparse structure latent
flow_model = self.models['sparse_structure_flow_model']
reso = flow_model.resolution
in_channels = flow_model.in_channels
noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device)
sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
if self.low_vram:
flow_model.to(self.device)
z_s = self.sparse_structure_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling sparse structure",
).samples
if self.low_vram:
flow_model.cpu()
# Decode sparse structure latent
decoder = self.models['sparse_structure_decoder']
if self.low_vram:
decoder.to(self.device)
decoded = decoder(z_s)>0
if self.low_vram:
decoder.cpu()
if resolution != decoded.shape[2]:
ratio = decoded.shape[2] // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
return coords
def sample_shape_slat(
self,
cond: dict,
flow_model,
coords: torch.Tensor,
sampler_params: dict = {},
) -> SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
noise = SparseTensor(
feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
coords=coords,
)
sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
if self.low_vram:
flow_model.to(self.device)
slat = self.shape_slat_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling shape SLat",
).samples
if self.low_vram:
flow_model.cpu()
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
return slat
def sample_shape_slat_cascade(
self,
lr_cond: dict,
cond: dict,
flow_model_lr,
flow_model,
lr_resolution: int,
resolution: int,
coords: torch.Tensor,
sampler_params: dict = {},
max_num_tokens: int = 49152,
visualize_hr_coords: bool = False,
visualize_save_dir: str = None,
) -> SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
visualize_hr_coords (bool): Whether to visualize high-resolution coordinates after upsampling.
visualize_save_dir (str): Directory to save visualization images. If None, displays interactively.
"""
# LR
noise = SparseTensor(
feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device),
coords=coords,
)
sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
if self.low_vram:
flow_model_lr.to(self.device)
slat = self.shape_slat_sampler.sample(
flow_model_lr,
noise,
**lr_cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling shape SLat",
).samples
get_logger().debug(f"DEBUG SLAT: coords={slat.coords.shape}, spatial_shape={slat.spatial_shape}, "
f"coords_max={slat.coords[:,1:].contiguous().max(dim=0).values}, dtype={slat.feats.dtype}")
if self.low_vram:
flow_model_lr.cpu()
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
get_logger().debug(f"DEBUG SLAT[after *std + mean]: coords={slat.coords.shape}, spatial_shape={slat.spatial_shape}, "
f"coords_max={slat.coords[:,1:].contiguous().max(dim=0).values}, dtype={slat.feats.dtype}")
# Upsample
if self.low_vram:
self.models['shape_slat_decoder'].to(self.device)
self.models['shape_slat_decoder'].low_vram = True
hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4)
get_logger().debug(f"DEBUG CASCADE: hr_coords shape={hr_coords.shape}, max={hr_coords[:,1:].max(dim=0).values}, unique_x={hr_coords[:,1].unique().shape[0]}, unique_y={hr_coords[:,2].unique().shape[0]}, unique_z={hr_coords[:,3].unique().shape[0]}")
# Visualize high-resolution coordinates if requested
if visualize_hr_coords:
print("\n=== High-Resolution Coordinates Visualization (After Upsampling) ===")
self.analyze_sparse_structure(hr_coords)
# Calculate effective resolution for visualization
effective_resolution = lr_resolution * 4 # upsample_times=4
if visualize_save_dir:
import os
os.makedirs(visualize_save_dir, exist_ok=True)
base_path = os.path.join(visualize_save_dir, f"hr_coords_{resolution}_upsampled")
self.visualize_sparse_structure_matplotlib(
hr_coords,
title=f"HR Coordinates - Upsampled {resolution} (effective res: {effective_resolution})",
save_path=f"{base_path}_3d.png"
)
self.visualize_sparse_structure_voxel(
hr_coords,
resolution=effective_resolution,
title=f"HR Voxel Grid - Upsampled {resolution} (effective res: {effective_resolution})",
save_path=f"{base_path}_voxel.png"
)
self.visualize_sparse_structure_projections(
hr_coords,
resolution=effective_resolution,
title=f"HR Projections - Upsampled {resolution} (effective res: {effective_resolution})",
save_path=f"{base_path}_projections.png"
)
self.visualize_sparse_structure_multi_view(
hr_coords,
title=f"HR Multi-View - Upsampled {resolution} (effective res: {effective_resolution})",
save_path=f"{base_path}_multi_view.png"
)
else:
# Interactive visualization (no saving)
self.visualize_sparse_structure_matplotlib(
hr_coords,
title=f"HR Coordinates - Upsampled {resolution} (effective res: {effective_resolution})"
)
self.visualize_sparse_structure_voxel(
hr_coords,
resolution=effective_resolution,
title=f"HR Voxel Grid - Upsampled {resolution} (effective res: {effective_resolution})"
)
self.visualize_sparse_structure_projections(
hr_coords,
resolution=effective_resolution,
title=f"HR Projections - Upsampled {resolution} (effective res: {effective_resolution})"
)
self.visualize_sparse_structure_multi_view(
hr_coords,
title=f"HR Multi-View - Upsampled {resolution} (effective res: {effective_resolution})"
)
print("=== HR Coordinates Visualization Complete ===\n")
coord_set = set(map(tuple, hr_coords[:, 1:].cpu().numpy().tolist()))
has_neighbor = sum(1 for c in coord_set if any(
(c[0]+dx, c[1]+dy, c[2]+dz) in coord_set
for dx,dy,dz in [(1,0,0),(-1,0,0),(0,1,0),(0,-1,0),(0,0,1),(0,0,-1)]
)) / len(coord_set)
get_logger().debug(f"DEBUG TOPOLOGY: coords={len(coord_set)}, neighbor_coverage={has_neighbor:.3f}")
if self.low_vram:
self.models['shape_slat_decoder'].cpu()
self.models['shape_slat_decoder'].low_vram = False
hr_resolution = resolution
while True:
quant_coords = torch.cat([
hr_coords[:, :1],
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
coords = quant_coords.unique(dim=0)
get_logger().debug(f"DEBUG COORDS: num_tokens={coords.shape[0]}, max={coords[:,1:].max(dim=0).values}")
num_tokens = coords.shape[0]
if num_tokens < max_num_tokens or hr_resolution == 1024:
if hr_resolution != resolution:
print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.")
break
hr_resolution -= 128
# Visualize quantized coordinates if requested
if visualize_hr_coords:
print("\n=== Quantized Coordinates Visualization (After Resolution Adjustment) ===")
self.analyze_sparse_structure(coords)
if visualize_save_dir:
import os
os.makedirs(visualize_save_dir, exist_ok=True)
base_path = os.path.join(visualize_save_dir, f"quantized_coords_{hr_resolution}")
self.visualize_sparse_structure_matplotlib(
coords,
title=f"Quantized Coords - Resolution {hr_resolution}",
save_path=f"{base_path}_3d.png"
)
self.visualize_sparse_structure_voxel(
coords,
resolution=hr_resolution // 16,
title=f"Quantized Voxel Grid - Resolution {hr_resolution}",
save_path=f"{base_path}_voxel.png"
)
self.visualize_sparse_structure_projections(
coords,
resolution=hr_resolution // 16,
title=f"Quantized Projections - Resolution {hr_resolution}",
save_path=f"{base_path}_projections.png"
)
self.visualize_sparse_structure_multi_view(
coords,
title=f"Quantized Multi-View - Resolution {hr_resolution}",
save_path=f"{base_path}_multi_view.png"
)
else:
# Interactive visualization (no saving)
self.visualize_sparse_structure_matplotlib(
coords,
title=f"Quantized Coords - Resolution {hr_resolution}"
)
self.visualize_sparse_structure_voxel(
coords,
resolution=hr_resolution // 16,
title=f"Quantized Voxel Grid - Resolution {hr_resolution}"
)
self.visualize_sparse_structure_projections(
coords,
resolution=hr_resolution // 16,
title=f"Quantized Projections - Resolution {hr_resolution}"
)
self.visualize_sparse_structure_multi_view(
coords,
title=f"Quantized Multi-View - Resolution {hr_resolution}"
)
print("=== Quantized Coordinates Visualization Complete ===\n")
# Sample structured latent
noise = SparseTensor(
feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
coords=coords,
)
sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
if self.low_vram:
flow_model.to(self.device)
slat = self.shape_slat_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling shape SLat",
).samples
if self.low_vram:
flow_model.cpu()
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
get_logger().debug(f"CASCADE final slat: nan={torch.isnan(slat.feats).any().item()} inf={torch.isinf(slat.feats).any().item()} max={slat.feats.abs().max().item():.4f} dtype={slat.feats.dtype}")
# Visualize final SLat features if requested
if visualize_hr_coords:
print("\n=== Final SLat Features Visualization (After Denormalization) ===")
self.analyze_slat_features(slat)
if visualize_save_dir:
import os
os.makedirs(visualize_save_dir, exist_ok=True)
base_path = os.path.join(visualize_save_dir, f"final_slat_{hr_resolution}")
# Visualize first few features
for i in range(min(3, slat.feats.shape[1])):
self.visualize_slat_features(
slat,
title=f"Final SLat Feature {i} - Resolution {hr_resolution}",
save_path=f"{base_path}_feature{i}.png",
feature_idx=i
)
else:
# Interactive visualization (no saving)
for i in range(min(3, slat.feats.shape[1])):
self.visualize_slat_features(
slat,
title=f"Final SLat Feature {i} - Resolution {hr_resolution}",
feature_idx=i
)
print("=== Final SLat Features Visualization Complete ===\n")
return slat, hr_resolution
def decode_shape_slat(
self,
slat: SparseTensor,
resolution: int,
) -> Tuple[List[Mesh], List[SparseTensor]]:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
List[Mesh]: The decoded meshes.
List[SparseTensor]: The decoded substructures.
"""
self.models['shape_slat_decoder'].set_resolution(resolution)
if self.low_vram:
self.models['shape_slat_decoder'].to(self.device)
self.models['shape_slat_decoder'].low_vram = True
ret = self.models['shape_slat_decoder'](slat, return_subs=True)
if self.low_vram:
self.models['shape_slat_decoder'].cpu()
self.models['shape_slat_decoder'].low_vram = False
return ret
def sample_tex_slat(
self,
cond: dict,
flow_model,
shape_slat: SparseTensor,
sampler_params: dict = {},
visualize: bool = False,
visualize_save_dir: str = None,
pipeline_type: str = 'unknown',
) -> SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
shape_slat (SparseTensor): The structured latent for shape
sampler_params (dict): Additional parameters for the sampler.
visualize (bool): Whether to visualize shape + colored texture slat.
visualize_save_dir (str): Directory to save visualizations. None = interactive.
pipeline_type (str): Pipeline name used in visualization titles.
"""
# Sample structured latent
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
shape_slat_norm = (shape_slat - mean) / std
in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
noise = shape_slat_norm.replace(feats=torch.randn(shape_slat_norm.coords.shape[0], in_channels - shape_slat_norm.feats.shape[1]).to(self.device))
sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
if self.low_vram:
flow_model.to(self.device)
slat = self.tex_slat_sampler.sample(
flow_model,
noise,
concat_cond=shape_slat_norm,
**cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling texture SLat",
).samples
if self.low_vram:
flow_model.cpu()
# Visualize: shape structure + colored texture slat
if visualize:
import os
print("\n=== Texture SLat Visualization ===")
self.analyze_slat_features(slat)
if visualize_save_dir:
os.makedirs(visualize_save_dir, exist_ok=True)
base_path = os.path.join(visualize_save_dir, f"tex_slat_{pipeline_type}")
# 1. Shape-only structure (occupancy/geometry)
self.visualize_sparse_structure_projections(
shape_slat.coords,
title=f"Shape Structure - {pipeline_type}",
save_path=f"{base_path}_shape_projections.png",
)
# 2. Combined: shape colored by tex-slat latent features (pseudo-RGB from first 3 dims)
self.visualize_tex_slat_colored(
slat,
title=f"Tex SLat Colored - {pipeline_type}",
save_path=f"{base_path}_colored.png",
)
# 3. Per-feature projections (first 3 latent dims)
for i in range(min(3, slat.feats.shape[1])):
self.visualize_slat_features(
slat,
title=f"Tex Feature {i} - {pipeline_type}",
save_path=f"{base_path}_feature{i}.png",
feature_idx=i,
)
else:
self.visualize_sparse_structure_projections(
shape_slat.coords,
title=f"Shape Structure - {pipeline_type}",
)
self.visualize_tex_slat_colored(
slat,
title=f"Tex SLat Colored - {pipeline_type}",
)
for i in range(min(3, slat.feats.shape[1])):
self.visualize_slat_features(
slat,
title=f"Tex Feature {i} - {pipeline_type}",
feature_idx=i,
)
print("=== Texture SLat Visualization Complete ===\n")
std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
return slat
def decode_tex_slat(
self,
slat: SparseTensor,
subs: List[SparseTensor],
) -> SparseTensor:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
SparseTensor: The decoded texture voxels
"""
if self.low_vram:
self.models['tex_slat_decoder'].to(self.device)
ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5
if self.low_vram:
self.models['tex_slat_decoder'].cpu()
return ret
def visualize_sparse_structure_matplotlib(self, coords: torch.Tensor, title: str = "Sparse Structure", save_path: str = None):
"""
Visualize sparse structure coordinates using matplotlib 3D scatter plot.
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
title: Title for the plot
save_path: Optional path to save the figure
"""
# Convert to numpy and extract spatial coordinates (drop batch index)
coords_np = coords.cpu().numpy()
x = coords_np[:, 1] # x coordinate
y = coords_np[:, 2] # y coordinate
z = coords_np[:, 3] # z coordinate
# Create 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot points
scatter = ax.scatter(x, y, z, c=z, cmap='viridis', s=1, alpha=0.6)
# Set labels and title
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title(f'{title}\n{len(coords)} occupied voxels')
# Add colorbar
plt.colorbar(scatter, label='Z coordinate')
# Set equal aspect ratio
max_range = np.array([x.max()-x.min(), y.max()-y.min(), z.max()-z.min()]).max() / 2.0
mid_x = (x.max()+x.min()) * 0.5
mid_y = (y.max()+y.min()) * 0.5
mid_z = (z.max()+z.min()) * 0.5
ax.set_xlim(mid_x - max_range, mid_x + max_range)
ax.set_ylim(mid_y - max_range, mid_y + max_range)
ax.set_zlim(mid_z - max_range, mid_z + max_range)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved matplotlib visualization to {save_path}")
plt.show()
plt.close()
def visualize_sparse_structure_voxel(self, coords: torch.Tensor, resolution: int = 32, title: str = "Sparse Structure", save_path: str = None):
"""
Visualize sparse structure as a 3D voxel grid.
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
resolution: Grid resolution (e.g., 32 for 32³ grid)
title: Title for the plot
save_path: Optional path to save the figure
"""
# Create empty 3D grid
grid = np.zeros((resolution, resolution, resolution), dtype=bool)
# Fill in occupied voxels
coords_np = coords.cpu().numpy()
for coord in coords_np:
_, x, y, z = coord
if 0 <= x < resolution and 0 <= y < resolution and 0 <= z < resolution:
grid[x, y, z] = True
# Get coordinates of occupied voxels
x, y, z = np.where(grid)
# Create 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot voxels
ax.scatter(x, y, z, c=z, cmap='viridis', s=10, alpha=0.3)
# Set labels
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title(f'{title}\n{len(coords)} occupied voxels / {resolution**3} total')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved voxel visualization to {save_path}")
plt.show()
plt.close()
def visualize_sparse_structure_projections(self, coords: torch.Tensor, resolution: int = 32, title: str = "Sparse Structure", save_path: str = None):
"""
Visualize sparse structure using 2D projections (XY, XZ, YZ planes).
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
resolution: Grid resolution
title: Title for the plot
save_path: Optional path to save the figure
"""
coords_np = coords.cpu().numpy()
x = coords_np[:, 1]
y = coords_np[:, 2]
z = coords_np[:, 3]
# Create figure with 3 subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# XY projection (looking down Z axis)
axes[0].scatter(x, y, c=z, cmap='viridis', s=1, alpha=0.5)
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
axes[0].set_title('XY Projection (Top View)')
axes[0].set_xlim(0, resolution)
axes[0].set_ylim(0, resolution)
axes[0].set_aspect('equal')
# XZ projection (looking down Y axis)
axes[1].scatter(x, z, c=y, cmap='viridis', s=1, alpha=0.5)
axes[1].set_xlabel('X')
axes[1].set_ylabel('Z')
axes[1].set_title('XZ Projection (Side View)')
axes[1].set_xlim(0, resolution)
axes[1].set_ylim(0, resolution)
axes[1].set_aspect('equal')
# YZ projection (looking down X axis)
axes[2].scatter(y, z, c=x, cmap='viridis', s=1, alpha=0.5)
axes[2].set_xlabel('Y')
axes[2].set_ylabel('Z')
axes[2].set_title('YZ Projection (Front View)')
axes[2].set_xlim(0, resolution)
axes[2].set_ylim(0, resolution)
axes[2].set_aspect('equal')
plt.suptitle(f'{title}\n{len(coords)} occupied voxels', fontsize=14)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved projections visualization to {save_path}")
plt.show()
plt.close()
def visualize_sparse_structure_multi_view(self, coords: torch.Tensor, title: str = "Sparse Structure", save_path: str = None):
"""
Visualize sparse structure with multiple views (3D + 2D projections).
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
title: Title for the plot
save_path: Optional path to save the figure
"""
import matplotlib.pyplot as plt
import numpy as np
coords_np = coords.cpu().numpy()
x, y, z = coords_np[:, 1], coords_np[:, 2], coords_np[:, 3]
# Create multi-view visualization
fig = plt.figure(figsize=(18, 6))
# 3D scatter plot
ax1 = fig.add_subplot(131, projection='3d')
ax1.scatter(x, y, z, c=z, cmap='viridis', s=1, alpha=0.6)
ax1.set_title('3D View')
ax1.set_xlabel('X'); ax1.set_ylabel('Y'); ax1.set_zlabel('Z')
# XY projection
ax2 = fig.add_subplot(132)
ax2.scatter(x, y, c=z, cmap='viridis', s=1, alpha=0.5)
ax2.set_title('XY Projection')
ax2.set_xlabel('X'); ax2.set_ylabel('Y')
ax2.set_aspect('equal')
# XZ projection
ax3 = fig.add_subplot(133)
ax3.scatter(x, z, c=y, cmap='viridis', s=1, alpha=0.5)
ax3.set_title('XZ Projection')
ax3.set_xlabel('X'); ax3.set_ylabel('Z')
ax3.set_aspect('equal')
plt.suptitle(f'{title}\n{len(coords)} occupied voxels', fontsize=14)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved multi-view visualization to {save_path}")
plt.show()
plt.close()
def analyze_sparse_structure(self, coords: torch.Tensor):
"""
Analyze and print statistics about the sparse structure.
Args:
coords: torch.Tensor of shape [N, 4]
"""
coords_np = coords.cpu().numpy()
x, y, z = coords_np[:, 1], coords_np[:, 2], coords_np[:, 3]
print(f"Sparse Structure Analysis:")
print(f" Total occupied voxels: {len(coords)}")
print(f" X range: [{x.min()}, {x.max()}]")
print(f" Y range: [{y.min()}, {y.max()}]")
print(f" Z range: [{z.min()}, {z.max()}]")
print(f" Center: [{x.mean():.1f}, {y.mean():.1f}, {z.mean():.1f}]")
print(f" Std dev: [{x.std():.1f}, {y.std():.1f}, {z.std():.1f}]")
print(f" Bounding box volume: {(x.max()-x.min()) * (y.max()-y.min()) * (z.max()-z.min())}")
def visualize_slat_features(self, slat: SparseTensor, title: str = "SLat Features", save_path: str = None, feature_idx: int = 0):
"""
Visualize features from a SparseTensor (shape SLat).
Args:
slat: SparseTensor with features at sparse coordinates
title: Title for the plot
save_path: Optional path to save the figure
feature_idx: Which feature dimension to visualize (default: 0)
"""
coords_np = slat.coords.cpu().numpy()
feats_np = slat.feats.cpu().numpy()
# Extract coordinates and selected feature
x = coords_np[:, 1]
y = coords_np[:, 2]
z = coords_np[:, 3]
feature_values = feats_np[:, feature_idx]
# Create 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot points colored by feature value
scatter = ax.scatter(x, y, z, c=feature_values, cmap='viridis', s=1, alpha=0.6)
# Set labels and title
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title(f'{title}\nFeature {feature_idx} | Range: [{feature_values.min():.3f}, {feature_values.max():.3f}]')
# Add colorbar
plt.colorbar(scatter, label=f'Feature {feature_idx} Value')
# Set equal aspect ratio
max_range = np.array([x.max()-x.min(), y.max()-y.min(), z.max()-z.min()]).max() / 2.0
mid_x = (x.max()+x.min()) * 0.5
mid_y = (y.max()+y.min()) * 0.5
mid_z = (z.max()+z.min()) * 0.5
ax.set_xlim(mid_x - max_range, mid_x + max_range)
ax.set_ylim(mid_y - max_range, mid_y + max_range)
ax.set_zlim(mid_z - max_range, mid_z + max_range)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved SLat feature visualization to {save_path}")
plt.show()
plt.close()
def visualize_tex_slat_colored(self, slat: SparseTensor, title: str = "Tex SLat Colored", save_path: str = None):
"""
Visualize texture SLat with points colored by the first 3 latent feature dimensions
mapped to R, G, B — giving a pseudo-color view of the texture distribution across the shape.
Also shows three 2D projection panels (XY/XZ/YZ) beside the 3D view so you can see
coverage completeness at a glance.
Args:
slat: SparseTensor with texture latent features [N, C]
title: Title for the plot
save_path: Optional path to save the figure. None = interactive display.
"""
import numpy as np
coords_np = slat.coords.cpu().float().numpy()
feats_np = slat.feats.cpu().float().numpy()
x = coords_np[:, 1]
y = coords_np[:, 2]
z = coords_np[:, 3]
# Build per-point RGB from first 3 feature dims, normalised to [0, 1]
n_color_dims = min(3, feats_np.shape[1])
rgb = feats_np[:, :n_color_dims].copy()
for ch in range(n_color_dims):
lo, hi = rgb[:, ch].min(), rgb[:, ch].max()
rgb[:, ch] = (rgb[:, ch] - lo) / (hi - lo + 1e-8)
if n_color_dims < 3:
pad = np.ones((rgb.shape[0], 3 - n_color_dims))
rgb = np.concatenate([rgb, pad], axis=1)
rgb = np.clip(rgb, 0.0, 1.0)
fig = plt.figure(figsize=(22, 6))
fig.suptitle(f'{title} ({len(x)} voxels, {feats_np.shape[1]} feat dims)', fontsize=13)
# 3D scatter coloured by pseudo-RGB
ax3d = fig.add_subplot(141, projection='3d')
ax3d.scatter(x, y, z, c=rgb, s=1, alpha=0.6)
ax3d.set_xlabel('X'); ax3d.set_ylabel('Y'); ax3d.set_zlabel('Z')
ax3d.set_title('3D (pseudo-RGB)')
# XY projection
ax_xy = fig.add_subplot(142)
ax_xy.scatter(x, y, c=rgb, s=1, alpha=0.5)
ax_xy.set_xlabel('X'); ax_xy.set_ylabel('Y')
ax_xy.set_title('XY (top)')
ax_xy.set_aspect('equal')
# XZ projection
ax_xz = fig.add_subplot(143)
ax_xz.scatter(x, z, c=rgb, s=1, alpha=0.5)
ax_xz.set_xlabel('X'); ax_xz.set_ylabel('Z')
ax_xz.set_title('XZ (side)')
ax_xz.set_aspect('equal')
# YZ projection
ax_yz = fig.add_subplot(144)
ax_yz.scatter(y, z, c=rgb, s=1, alpha=0.5)
ax_yz.set_xlabel('Y'); ax_yz.set_ylabel('Z')
ax_yz.set_title('YZ (front)')
ax_yz.set_aspect('equal')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved tex-slat colored visualization to {save_path}")
plt.show()
plt.close()
def visualize_decoded_mesh(self, mesh, title: str = "Decoded Mesh", save_path_prefix: str = None):
"""
Visualize a decoded triangle mesh (vertices + faces).
Renders four panels:
- 3D scatter of vertices coloured by Z (subsampled to ≤50k points so matplotlib doesn't choke)
- XY / XZ / YZ 2D projections
Saves four separate PNGs when save_path_prefix is given (one per panel style matches
the naming convention used elsewhere in the pipeline):
<prefix>_3d.png, <prefix>_projections.png
"""
import numpy as np
import os
verts = mesh.vertices.cpu().float().numpy() # [V, 3]
n_verts = verts.shape[0]
n_faces = mesh.faces.shape[0]
MAX_SCATTER = 50_000
if n_verts > MAX_SCATTER:
idx = np.random.choice(n_verts, MAX_SCATTER, replace=False)
v = verts[idx]
else:
v = verts
x, y, z = v[:, 0], v[:, 1], v[:, 2]
subtitle = f"{n_verts:,} vertices {n_faces:,} faces" + (
f" (scatter: {len(x):,} sampled)" if n_verts > MAX_SCATTER else "")
# --- 3D scatter ---
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c=z, cmap='viridis', s=1, alpha=0.6)
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_title(f'{title}\n{subtitle}')
plt.tight_layout()
if save_path_prefix:
p = f"{save_path_prefix}_3d.png"
plt.savefig(p, dpi=150, bbox_inches='tight')
print(f"Saved decoded mesh 3D to {p}")
plt.show(); plt.close()
# --- 3-panel 2D projections ---
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
axes[0].scatter(x, y, c=z, cmap='viridis', s=1, alpha=0.5)
axes[0].set_xlabel('X'); axes[0].set_ylabel('Y'); axes[0].set_title('XY (top)')
axes[0].set_aspect('equal')
axes[1].scatter(x, z, c=y, cmap='viridis', s=1, alpha=0.5)
axes[1].set_xlabel('X'); axes[1].set_ylabel('Z'); axes[1].set_title('XZ (side)')
axes[1].set_aspect('equal')
axes[2].scatter(y, z, c=x, cmap='viridis', s=1, alpha=0.5)
axes[2].set_xlabel('Y'); axes[2].set_ylabel('Z'); axes[2].set_title('YZ (front)')
axes[2].set_aspect('equal')
plt.suptitle(f'{title}\n{subtitle}', fontsize=13)
plt.tight_layout()
if save_path_prefix:
p = f"{save_path_prefix}_projections.png"
plt.savefig(p, dpi=150, bbox_inches='tight')
print(f"Saved decoded mesh projections to {p}")
plt.show(); plt.close()
def visualize_mesh_with_voxel(self, mv, title: str = "MeshWithVoxel", save_path_prefix: str = None):
"""
Visualize a MeshWithVoxel: overlays mesh vertices (grey) and texture voxel positions
(coloured by pseudo-RGB from first 3 attr dims) in one 5-panel figure.
Panels: 3D overlay, XY / XZ / YZ 2D projections.
"""
import numpy as np
verts = mv.vertices.cpu().float().numpy()
n_verts = verts.shape[0]
n_faces = mv.faces.shape[0]
coords = mv.coords.cpu().float().numpy() # [N, 3] (already stripped of batch dim)
attrs = mv.attrs.cpu().float().numpy() # [N, C]
n_vox = coords.shape[0]
MAX_SCATTER = 50_000
if n_verts > MAX_SCATTER:
vi = np.random.choice(n_verts, MAX_SCATTER, replace=False)
vp = verts[vi]
else:
vp = verts
if n_vox > MAX_SCATTER:
ci = np.random.choice(n_vox, MAX_SCATTER, replace=False)
cp = coords[ci]; ap = attrs[ci]
else:
cp = coords; ap = attrs
# Build pseudo-RGB from first 3 attr dims
n_color = min(3, ap.shape[1])
rgb = ap[:, :n_color].copy()
for ch in range(n_color):
lo, hi = rgb[:, ch].min(), rgb[:, ch].max()
rgb[:, ch] = (rgb[:, ch] - lo) / (hi - lo + 1e-8)
if n_color < 3:
rgb = np.concatenate([rgb, np.ones((rgb.shape[0], 3 - n_color))], axis=1)
rgb = np.clip(rgb, 0, 1)
subtitle = (f"Mesh: {n_verts:,}v {n_faces:,}f | Voxels: {n_vox:,}"
+ (" (both subsampled)" if n_verts > MAX_SCATTER or n_vox > MAX_SCATTER else ""))
# Voxel coords are integer indices; convert to world space for overlay
vox_world = cp * mv.voxel_size + mv.origin.cpu().numpy()
vx, vy, vz = vox_world[:, 0], vox_world[:, 1], vox_world[:, 2]
mx, my, mz = vp[:, 0], vp[:, 1], vp[:, 2]
# --- 3D overlay ---
fig = plt.figure(figsize=(11, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(mx, my, mz, c='lightgrey', s=1, alpha=0.3, label='mesh verts')
ax.scatter(vx, vy, vz, c=rgb, s=2, alpha=0.6, label='tex voxels')
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_title(f'{title}\n{subtitle}')
plt.tight_layout()
if save_path_prefix:
p = f"{save_path_prefix}_3d.png"
plt.savefig(p, dpi=150, bbox_inches='tight')
print(f"Saved MeshWithVoxel 3D to {p}")
plt.show(); plt.close()
# --- 4-panel 2D projections ---
fig, axes = plt.subplots(1, 4, figsize=(24, 6))
def proj(ax_, hx, hy, hz, label):
ax_.scatter(hx, hy, c='lightgrey', s=1, alpha=0.25)
ax_.scatter(vx if label == 'XY' else (vx if label == 'XZ' else vy),
vy if label == 'XY' else (vz if label == 'XZ' else vz),
c=rgb, s=1, alpha=0.5)
ax_.set_aspect('equal')
axes[0].scatter(mx, my, c='lightgrey', s=1, alpha=0.25)
axes[0].scatter(vx, vy, c=rgb, s=1, alpha=0.5)
axes[0].set_xlabel('X'); axes[0].set_ylabel('Y'); axes[0].set_title('XY (top)'); axes[0].set_aspect('equal')
axes[1].scatter(mx, mz, c='lightgrey', s=1, alpha=0.25)
axes[1].scatter(vx, vz, c=rgb, s=1, alpha=0.5)
axes[1].set_xlabel('X'); axes[1].set_ylabel('Z'); axes[1].set_title('XZ (side)'); axes[1].set_aspect('equal')
axes[2].scatter(my, mz, c='lightgrey', s=1, alpha=0.25)
axes[2].scatter(vy, vz, c=rgb, s=1, alpha=0.5)
axes[2].set_xlabel('Y'); axes[2].set_ylabel('Z'); axes[2].set_title('YZ (front)'); axes[2].set_aspect('equal')
# 4th panel: voxel coverage ratio as bar chart per axis
axes[3].axis('off')
info = (f"Mesh vertices : {n_verts:,}\n"
f"Mesh faces : {n_faces:,}\n"
f"Tex voxels : {n_vox:,}\n"
f"Voxel size : {mv.voxel_size:.5f}\n"
f"Voxel world X : [{vx.min():.3f}, {vx.max():.3f}]\n"
f"Voxel world Y : [{vy.min():.3f}, {vy.max():.3f}]\n"
f"Voxel world Z : [{vz.min():.3f}, {vz.max():.3f}]\n"
f"Attr dims : {mv.attrs.shape[1]}\n"
f"Attr range : [{mv.attrs.min().item():.4f}, {mv.attrs.max().item():.4f}]")
axes[3].text(0.05, 0.95, info, transform=axes[3].transAxes,
fontsize=10, verticalalignment='top', fontfamily='monospace')
axes[3].set_title('Stats')
plt.suptitle(f'{title}\n{subtitle}', fontsize=13)
plt.tight_layout()
if save_path_prefix:
p = f"{save_path_prefix}_projections.png"
plt.savefig(p, dpi=150, bbox_inches='tight')
print(f"Saved MeshWithVoxel projections to {p}")
plt.show(); plt.close()
def analyze_slat_features(self, slat: SparseTensor):
"""
Analyze and print statistics about SLat features.
Args:
slat: SparseTensor with features
"""
coords_np = slat.coords.cpu().numpy()
feats_np = slat.feats.cpu().numpy()
print(f"\nSLat Features Analysis:")
print(f" Number of tokens: {slat.coords.shape[0]}")
print(f" Feature dimensions: {slat.feats.shape[1]}")
print(f" Feature statistics:")
for i in range(min(5, slat.feats.shape[1])): # Show first 5 features
feat = feats_np[:, i]
print(f" Feature {i}: min={feat.min():.4f}, max={feat.max():.4f}, mean={feat.mean():.4f}, std={feat.std():.4f}")
print(f" NaN values: {np.isnan(feats_np).any()}")
print(f" Inf values: {np.isinf(feats_np).any()}")
print(f" Coordinate range: X=[{coords_np[:, 1].min()}, {coords_np[:, 1].max()}], "
f"Y=[{coords_np[:, 2].min()}, {coords_np[:, 2].max()}], "
f"Z=[{coords_np[:, 3].min()}, {coords_np[:, 3].max()}]")
@torch.no_grad()
def decode_latent(
self,
shape_slat: SparseTensor,
tex_slat: SparseTensor,
resolution: int,
visualize: bool = False,
visualize_save_dir: str = None,
pipeline_type: str = 'unknown',
) -> List[MeshWithVoxel]:
"""
Decode the latent codes.
Args:
shape_slat (SparseTensor): The structured latent for shape.
tex_slat (SparseTensor): The structured latent for texture.
resolution (int): The resolution of the output.
"""
L = get_logger()
section(f"decode_latent resolution={resolution}")
section("decode_shape_slat")
log_sparse(shape_slat, "shape_slat-in")
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
L.info(f" {elapsed()} decode_shape_slat produced {len(meshes)} mesh(es)")
for i, m in enumerate(meshes):
log_mesh(m.vertices, m.faces, f"shape_mesh[{i}]")
# Visualize decoded shape meshes
if visualize:
import os
for i, m in enumerate(meshes):
base = (os.path.join(visualize_save_dir, f"decoded_mesh_{pipeline_type}_s{i}")
if visualize_save_dir else None)
if base:
os.makedirs(visualize_save_dir, exist_ok=True)
self.visualize_decoded_mesh(m, title=f"Decoded Shape Mesh [{i}] - {pipeline_type}",
save_path_prefix=base)
section("decode_tex_slat")
log_sparse(tex_slat, "tex_slat-in")
tex_voxels = self.decode_tex_slat(tex_slat, subs)
L.info(f" {elapsed()} decode_tex_slat produced {len(tex_voxels)} voxel set(s)")
#Commented temporarily for speed.
"""
# Visualize texture voxels
if visualize:
import os
for i, v in enumerate(tex_voxels):
base = (os.path.join(visualize_save_dir, f"tex_voxels_{pipeline_type}_s{i}")
if visualize_save_dir else None)
if base:
os.makedirs(visualize_save_dir, exist_ok=True)
self.visualize_tex_slat_colored(v,
title=f"Tex Voxels [{i}] - {pipeline_type}",
save_path=f"{base}_colored.png" if base else None)
"""
section("build MeshWithVoxel")
out_mesh = []
for i, (m, v) in enumerate(zip(meshes, tex_voxels)):
L.info(f" {elapsed()} sample {i}:")
log_sparse(v, f"tex_voxels[{i}]")
L.info(f" spatial_shape={v.spatial_shape} "
f"coords_max={v.coords.max(dim=0).values.tolist()}")
log_mesh(m.vertices, m.faces, f"before-fill_holes[{i}]")
# CPU simplification via pyfqmr (QEM) before fill_holes to avoid
# GPU OOM from CuMesh's O(F*3) edge buffers on large meshes.
import pyfqmr, time
_target = 4_000_000
if m.faces.shape[0] > _target:
_v_np = m.vertices.detach().cpu().float().numpy()
_f_np = m.faces.detach().cpu().int().numpy()
L.info(f" [pyfqmr] simplify {m.faces.shape[0]}{_target} faces ...")
_t0 = time.perf_counter()
_simplifier = pyfqmr.Simplify()
_simplifier.setMesh(_v_np, _f_np)
_simplifier.simplify_mesh(_target, aggressiveness=7, verbose=False)
_sv, _sf, _sn = _simplifier.getMesh()
_dt = time.perf_counter() - _t0
L.info(f" [pyfqmr] done in {_dt:.2f}s → {len(_sv)} verts {len(_sf)} faces")
m.vertices = torch.from_numpy(_sv).to(dtype=torch.float32, device=m.vertices.device)
m.faces = torch.from_numpy(_sf).to(dtype=torch.int32, device=m.faces.device)
m.fill_holes()
log_mesh(m.vertices, m.faces, f"after-fill_holes[{i}]")
coords_xyz = v.coords[:, 1:].contiguous()
L.info(f" coords_xyz: {list(coords_xyz.shape)} "
f"range={[coords_xyz.min().item(), coords_xyz.max().item()]}")
L.info(f" attrs: {list(v.feats.shape)} "
f"range=[{v.feats.min().item():.4g},{v.feats.max().item():.4g}] "
f"NaN={torch.isnan(v.feats).any().item()}")
L.info(f" voxel_size={1/resolution:.6f} origin=[-0.5,-0.5,-0.5]")
mv = MeshWithVoxel(
m.vertices, m.faces,
origin = [-0.5, -0.5, -0.5],
voxel_size = 1 / resolution,
coords = coords_xyz,
attrs = v.feats,
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
layout=self.pbr_attr_layout
)
L.info(f" MeshWithVoxel.voxel_shape={mv.voxel_shape} "
f"voxel_size={mv.voxel_size} origin={mv.origin}")
# Visualize final MeshWithVoxel
if visualize:
import os
base = (os.path.join(visualize_save_dir, f"mesh_with_voxel_{pipeline_type}_s{i}")
if visualize_save_dir else None)
if base:
os.makedirs(visualize_save_dir, exist_ok=True)
self.visualize_mesh_with_voxel(mv,
title=f"MeshWithVoxel [{i}] - {pipeline_type}",
save_path_prefix=base)
out_mesh.append(mv)
section("decode_latent complete")
return out_mesh
@torch.no_grad()
def run(
self,
image: Image.Image,
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
shape_slat_sampler_params: dict = {},
tex_slat_sampler_params: dict = {},
preprocess_image: bool = True,
return_latent: bool = False,
pipeline_type: Optional[str] = None,
max_num_tokens: int = 49152,
visualize_sparse_structure: bool = False,
visualize_save_dir: str = None,
) -> List[MeshWithVoxel]:
"""
Run the pipeline.
Args:
image (Image.Image): The image prompt.
num_samples (int): The number of samples to generate.
seed (int): The random seed.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler.
tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
preprocess_image (bool): Whether to preprocess the image.
return_latent (bool): Whether to return the latent codes.
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
max_num_tokens (int): The maximum number of tokens to use.
visualize_sparse_structure (bool): Whether to visualize the sparse structure.
visualize_save_dir (str): Directory to save visualization images. If None, displays interactively.
"""
# Check pipeline type
pipeline_type = pipeline_type or self.default_pipeline_type
if pipeline_type == '512':
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found."
elif pipeline_type == '1024':
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
elif pipeline_type == '1024_cascade':
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
elif pipeline_type == '1536_cascade':
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
else:
raise ValueError(f"Invalid pipeline type: {pipeline_type}")
if preprocess_image:
image = self.preprocess_image(image)
torch.manual_seed(seed)
cond_512 = self.get_cond([image], 512)
cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
coords = self.sample_sparse_structure(
cond_512, ss_res,
num_samples, sparse_structure_sampler_params
)
# Visualize sparse structure if requested
if visualize_sparse_structure:
print("\n=== Sparse Structure Visualization ===")
self.analyze_sparse_structure(coords)
if visualize_save_dir:
import os
os.makedirs(visualize_save_dir, exist_ok=True)
base_path = os.path.join(visualize_save_dir, f"sparse_structure_{pipeline_type}_seed{seed}")
self.visualize_sparse_structure_matplotlib(
coords,
title=f"Sparse Structure - {pipeline_type} (seed={seed})",
save_path=f"{base_path}_3d.png"
)
self.visualize_sparse_structure_voxel(
coords,
resolution=ss_res,
title=f"Voxel Grid - {pipeline_type} (seed={seed})",
save_path=f"{base_path}_voxel.png"
)
self.visualize_sparse_structure_projections(
coords,
resolution=ss_res,
title=f"Projections - {pipeline_type} (seed={seed})",
save_path=f"{base_path}_projections.png"
)
self.visualize_sparse_structure_multi_view(
coords,
title=f"Multi-View - {pipeline_type} (seed={seed})",
save_path=f"{base_path}_multi_view.png"
)
else:
# Interactive visualization (no saving)
self.visualize_sparse_structure_matplotlib(
coords,
title=f"Sparse Structure - {pipeline_type} (seed={seed})"
)
self.visualize_sparse_structure_voxel(
coords,
resolution=ss_res,
title=f"Voxel Grid - {pipeline_type} (seed={seed})"
)
self.visualize_sparse_structure_projections(
coords,
resolution=ss_res,
title=f"Projections - {pipeline_type} (seed={seed})"
)
self.visualize_sparse_structure_multi_view(
coords,
title=f"Multi-View - {pipeline_type} (seed={seed})"
)
print("=== Visualization Complete ===\n")
if pipeline_type == '512':
shape_slat = self.sample_shape_slat(
cond_512, self.models['shape_slat_flow_model_512'],
coords, shape_slat_sampler_params
)
tex_slat = self.sample_tex_slat(
cond_512, self.models['tex_slat_flow_model_512'],
shape_slat, tex_slat_sampler_params,
visualize=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
pipeline_type=pipeline_type,
)
res = 512
elif pipeline_type == '1024':
shape_slat = self.sample_shape_slat(
cond_1024, self.models['shape_slat_flow_model_1024'],
coords, shape_slat_sampler_params
)
tex_slat = self.sample_tex_slat(
cond_1024, self.models['tex_slat_flow_model_1024'],
shape_slat, tex_slat_sampler_params,
visualize=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
pipeline_type=pipeline_type,
)
res = 1024
elif pipeline_type == '1024_cascade':
shape_slat, res = self.sample_shape_slat_cascade(
cond_512, cond_1024,
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
512, 1024,
coords, shape_slat_sampler_params,
max_num_tokens,
visualize_hr_coords=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
)
tex_slat = self.sample_tex_slat(
cond_1024, self.models['tex_slat_flow_model_1024'],
shape_slat, tex_slat_sampler_params,
visualize=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
pipeline_type=pipeline_type,
)
elif pipeline_type == '1536_cascade':
shape_slat, res = self.sample_shape_slat_cascade(
cond_512, cond_1024,
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
512, 1536,
coords, shape_slat_sampler_params,
max_num_tokens,
visualize_hr_coords=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
)
tex_slat = self.sample_tex_slat(
cond_1024, self.models['tex_slat_flow_model_1024'],
shape_slat, tex_slat_sampler_params,
visualize=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
pipeline_type=pipeline_type,
)
torch.cuda.empty_cache()
out_mesh = self.decode_latent(shape_slat, tex_slat, res,
visualize=visualize_sparse_structure,
visualize_save_dir=visualize_save_dir,
pipeline_type=pipeline_type)
if return_latent:
return out_mesh, (shape_slat, tex_slat, res)
else:
return out_mesh
from typing import *
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import trimesh
from .base import Pipeline
from . import samplers, rembg
from ..modules.sparse import SparseTensor
from ..modules import image_feature_extractor
import o_voxel
import cumesh
import nvdiffrast.torch as dr
import cv2
import flex_gemm
class Trellis2TexturingPipeline(Pipeline):
"""
Pipeline for inferring Trellis2 image-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
shape_slat_normalization (dict): The normalization parameters for the structured latent.
tex_slat_normalization (dict): The normalization parameters for the texture latent.
image_cond_model (Callable): The image conditioning model.
rembg_model (Callable): The model for removing background.
low_vram (bool): Whether to use low-VRAM mode.
"""
model_names_to_load = [
'shape_slat_encoder',
'tex_slat_decoder',
'tex_slat_flow_model_512',
'tex_slat_flow_model_1024'
]
def __init__(
self,
models: dict[str, nn.Module] = None,
tex_slat_sampler: samplers.Sampler = None,
tex_slat_sampler_params: dict = None,
shape_slat_normalization: dict = None,
tex_slat_normalization: dict = None,
image_cond_model: Callable = None,
rembg_model: Callable = None,
low_vram: bool = True,
):
if models is None:
return
super().__init__(models)
self.tex_slat_sampler = tex_slat_sampler
self.tex_slat_sampler_params = tex_slat_sampler_params
self.shape_slat_normalization = shape_slat_normalization
self.tex_slat_normalization = tex_slat_normalization
self.image_cond_model = image_cond_model
self.rembg_model = rembg_model
self.low_vram = low_vram
self.pbr_attr_layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
self._device = 'cpu'
@classmethod
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super().from_pretrained(path, config_file)
args = pipeline._pretrained_args
pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
pipeline.shape_slat_normalization = args['shape_slat_normalization']
pipeline.tex_slat_normalization = args['tex_slat_normalization']
pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
pipeline.low_vram = args.get('low_vram', True)
pipeline.pbr_attr_layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
pipeline._device = 'cpu'
return pipeline
def to(self, device: torch.device) -> None:
self._device = device
if not self.low_vram:
super().to(device)
self.image_cond_model.to(device)
if self.rembg_model is not None:
self.rembg_model.to(device)
def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
"""
Preprocess the input mesh.
"""
vertices = mesh.vertices
vertices_min = vertices.min(axis=0)
vertices_max = vertices.max(axis=0)
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
vertices = (vertices - center) * scale
tmp = vertices[:, 1].copy()
vertices[:, 1] = -vertices[:, 2]
vertices[:, 2] = tmp
assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False)
def preprocess_image(self, input: Image.Image) -> Image.Image:
"""
Preprocess the input image.
"""
# if has alpha channel, use it directly; otherwise, remove background
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, 3]
if not np.all(alpha == 255):
has_alpha = True
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
if has_alpha:
output = input
else:
input = input.convert('RGB')
if self.low_vram:
self.rembg_model.to(self.device)
output = self.rembg_model(input)
if self.low_vram:
self.rembg_model.cpu()
output_np = np.array(output)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = output.crop(bbox) # type: ignore
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
output = Image.fromarray((output * 255).astype(np.uint8))
return output
def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
self.image_cond_model.image_size = resolution
if self.low_vram:
self.image_cond_model.to(self.device)
cond = self.image_cond_model(image)
if self.low_vram:
self.image_cond_model.cpu()
if not include_neg_cond:
return {'cond': cond}
neg_cond = torch.zeros_like(cond)
return {
'cond': cond,
'neg_cond': neg_cond,
}
def encode_shape_slat(
self,
mesh: trimesh.Trimesh,
resolution: int = 1024,
) -> SparseTensor:
"""
Encode the meshes to structured latent.
Args:
mesh (trimesh.Trimesh): The mesh to encode.
resolution (int): The resolution of mesh
Returns:
SparseTensor: The encoded structured latent.
"""
vertices = torch.from_numpy(mesh.vertices).float()
faces = torch.from_numpy(mesh.faces).long()
voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
vertices.cpu(), faces.cpu(),
grid_size=resolution,
aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
face_weight=1.0,
boundary_weight=0.2,
regularization_weight=1e-2,
timing=True,
)
vertices = SparseTensor(
feats=dual_vertices * resolution - voxel_indices,
coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
).to(self.device)
intersected = vertices.replace(intersected).to(self.device)
if self.low_vram:
self.models['shape_slat_encoder'].to(self.device)
shape_slat = self.models['shape_slat_encoder'](vertices, intersected)
if self.low_vram:
self.models['shape_slat_encoder'].cpu()
return shape_slat
def sample_tex_slat(
self,
cond: dict,
flow_model,
shape_slat: SparseTensor,
sampler_params: dict = {},
) -> SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
shape_slat (SparseTensor): The structured latent for shape
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
shape_slat = (shape_slat - mean) / std
in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
if self.low_vram:
flow_model.to(self.device)
slat = self.tex_slat_sampler.sample(
flow_model,
noise,
concat_cond=shape_slat,
**cond,
**sampler_params,
verbose=True,
tqdm_desc="Sampling texture SLat",
).samples
if self.low_vram:
flow_model.cpu()
std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
return slat
def decode_tex_slat(
self,
slat: SparseTensor,
) -> SparseTensor:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
SparseTensor: The decoded texture voxels
"""
if self.low_vram:
self.models['tex_slat_decoder'].to(self.device)
ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5
if self.low_vram:
self.models['tex_slat_decoder'].cpu()
return ret
def postprocess_mesh(
self,
mesh: trimesh.Trimesh,
pbr_voxel: SparseTensor,
resolution: int = 1024,
texture_size: int = 1024,
) -> trimesh.Trimesh:
vertices = mesh.vertices
faces = mesh.faces
normals = mesh.vertex_normals
vertices_torch = torch.from_numpy(vertices).float().cuda()
faces_torch = torch.from_numpy(faces).int().cuda()
if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
uvs = mesh.visual.uv.copy()
uvs[:, 1] = 1 - uvs[:, 1]
uvs_torch = torch.from_numpy(uvs).float().cuda()
else:
_cumesh = cumesh.CuMesh()
_cumesh.init(vertices_torch, faces_torch)
vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
vertices_torch = vertices_torch.cuda()
faces_torch = faces_torch.cuda()
uvs_torch = uvs_torch.cuda()
vertices = vertices_torch.cpu().numpy()
faces = faces_torch.cpu().numpy()
uvs = uvs_torch.cpu().numpy()
normals = normals[vmap.cpu().numpy()]
# rasterize
ctx = dr.RasterizeCudaContext()
uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0)
rast, _ = dr.rasterize(
ctx, uvs_torch, faces_torch,
resolution=[texture_size, texture_size],
)
mask = rast[0, ..., 3] > 0
pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]
attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device)
if mask.any(): attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
pbr_voxel.feats,
pbr_voxel.coords,
shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
mode='trilinear',
)
# construct mesh
mask = mask.cpu().numpy()
base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
# extend
mask = (~mask).astype(np.uint8)
base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA)
metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None]
roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None]
alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None]
material = trimesh.visual.material.PBRMaterial(
baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
metallicFactor=1.0,
roughnessFactor=1.0,
alphaMode='OPAQUE',
doubleSided=True,
)
# Swap Y and Z axes, invert Y (common conversion for GLB compatibility)
vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1]
normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1]
uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate
textured_mesh = trimesh.Trimesh(
vertices=vertices,
faces=faces,
vertex_normals=normals,
process=False,
visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
)
return textured_mesh
@torch.no_grad()
def run(
self,
mesh: trimesh.Trimesh,
image: Image.Image,
seed: int = 42,
tex_slat_sampler_params: dict = {},
preprocess_image: bool = True,
resolution: int = 1024,
texture_size: int = 2048,
) -> trimesh.Trimesh:
"""
Run the pipeline.
Args:
mesh (trimesh.Trimesh): The mesh to texture.
image (Image.Image): The image prompt.
seed (int): The random seed.
tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
if preprocess_image:
image = self.preprocess_image(image)
mesh = self.preprocess_mesh(mesh)
torch.manual_seed(seed)
cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024)
shape_slat = self.encode_shape_slat(mesh, resolution)
tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024']
tex_slat = self.sample_tex_slat(
cond, tex_model,
shape_slat, tex_slat_sampler_params
)
pbr_voxel = self.decode_tex_slat(tex_slat)
out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)
return out_mesh
import importlib
__attributes = {
'MeshRenderer': 'mesh_renderer',
'VoxelRenderer': 'voxel_renderer',
'PbrMeshRenderer': 'pbr_mesh_renderer',
'EnvMap': 'pbr_mesh_renderer',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .mesh_renderer import MeshRenderer
from .voxel_renderer import VoxelRenderer
from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap
\ No newline at end of file
from typing import *
import torch
from easydict import EasyDict as edict
from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode
import torch.nn.functional as F
def intrinsics_to_projection(
intrinsics: torch.Tensor,
near: float,
far: float,
) -> torch.Tensor:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
ret[0, 0] = 2 * fx
ret[1, 1] = 2 * fy
ret[0, 2] = 2 * cx - 1
ret[1, 2] = - 2 * cy + 1
ret[2, 2] = (far + near) / (far - near)
ret[2, 3] = 2 * near * far / (near - far)
ret[3, 2] = 1.
return ret
class MeshRenderer:
"""
Renderer for the Mesh representation.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options={}, device='cuda'):
if 'dr' not in globals():
import nvdiffrast.torch as dr
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"chunk_size": None,
"antialias": True,
"clamp_barycentric_coords": False,
})
self.rendering_options.update(rendering_options)
self.glctx = dr.RasterizeCudaContext(device=device)
self.device=device
def render(
self,
mesh : Mesh,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
return_types = ["mask", "normal", "depth"],
transformation : Optional[torch.Tensor] = None
) -> edict:
"""
Render the mesh.
Args:
mesh : meshmodel
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal"
Returns:
edict based on return_types containing:
attr (torch.Tensor): [C, H, W] rendered attr image
depth (torch.Tensor): [H, W] rendered depth image
normal (torch.Tensor): [3, H, W] rendered normal image
mask (torch.Tensor): [H, W] rendered mask image
"""
if 'dr' not in globals():
import nvdiffrast.torch as dr
resolution = self.rendering_options["resolution"]
near = self.rendering_options["near"]
far = self.rendering_options["far"]
ssaa = self.rendering_options["ssaa"]
chunk_size = self.rendering_options["chunk_size"]
antialias = self.rendering_options["antialias"]
clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"]
if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
ret_dict = edict()
for type in return_types:
if type == "mask" :
ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device)
elif type == "depth":
ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device)
elif type == "normal":
ret_dict[type] = torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device)
elif type == "coord":
ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
elif type == "attr":
if isinstance(mesh, MeshWithVoxel):
ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device)
else:
ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device)
return ret_dict
perspective = intrinsics_to_projection(intrinsics, near, far)
full_proj = (perspective @ extrinsics).unsqueeze(0)
extrinsics = extrinsics.unsqueeze(0)
vertices = mesh.vertices.unsqueeze(0)
vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
if transformation is not None:
vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2))
vertices = vertices_homo[..., :3].contiguous()
vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2))
vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
faces = mesh.faces
if 'normal' in return_types:
v0 = vertices_camera[0, mesh.faces[:, 0], :3]
v1 = vertices_camera[0, mesh.faces[:, 1], :3]
v2 = vertices_camera[0, mesh.faces[:, 2], :3]
e0 = v1 - v0
e1 = v2 - v0
face_normal = torch.cross(e0, e1, dim=1)
face_normal = F.normalize(face_normal, dim=1)
face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, -face_normal)
out_dict = edict()
if chunk_size is None:
rast, rast_db = dr.rasterize(
self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)
)
if clamp_barycentric_coords:
rast[..., :2] = torch.clamp(rast[..., :2], 0, 1)
rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2]))
for type in return_types:
img = None
if type == "mask" :
img = (rast[..., -1:] > 0).float()
if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
elif type == "depth":
img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0]
if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
elif type == "normal" :
img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0]
if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
img = (img + 1) / 2
elif type == "coord":
img = dr.interpolate(vertices, rast, faces)[0]
if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
elif type == "attr":
if isinstance(mesh, MeshWithVoxel):
if 'grid_sample_3d' not in globals():
from flex_gemm.ops.grid_sample import grid_sample_3d
mask = rast[..., -1:] > 0
xyz = dr.interpolate(vertices, rast, faces)[0]
xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
img = grid_sample_3d(
mesh.attrs,
torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
mesh.voxel_shape,
xyz,
mode='trilinear'
)
img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
elif isinstance(mesh, MeshWithPbrMaterial):
tri_id = rast[0, :, :, -1:]
mask = tri_id > 0
uv_coords = mesh.uv_coords.reshape(1, -1, 2)
texc, texd = dr.interpolate(
uv_coords,
rast,
torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
rast_db=rast_db,
diff_attrs='all'
)
# Fix problematic texture coordinates
texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
texc = torch.clamp(texc, min=-1e3, max=1e3)
texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
texd = torch.clamp(texd, min=-1e3, max=1e3)
mid = mesh.material_ids[(tri_id - 1).long()]
imgs = {
'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device),
'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
}
for id, mat in enumerate(mesh.materials):
mat_mask = (mid == id).float() * mask.float()
mat_texc = texc * mat_mask
mat_texd = texd * mat_mask
if mat.base_color_texture is not None:
base_color = dr.texture(
mat.base_color_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['base_color'] += base_color * mat.base_color_factor * mat_mask
else:
imgs['base_color'] += mat.base_color_factor * mat_mask
if mat.metallic_texture is not None:
metallic = dr.texture(
mat.metallic_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['metallic'] += metallic * mat.metallic_factor * mat_mask
else:
imgs['metallic'] += mat.metallic_factor * mat_mask
if mat.roughness_texture is not None:
roughness = dr.texture(
mat.roughness_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['roughness'] += roughness * mat.roughness_factor * mat_mask
else:
imgs['roughness'] += mat.roughness_factor * mat_mask
if mat.alpha_mode == AlphaMode.OPAQUE:
imgs['alpha'] += 1.0 * mat_mask
else:
if mat.alpha_texture is not None:
alpha = dr.texture(
mat.alpha_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
if mat.alpha_mode == AlphaMode.MASK:
imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
imgs['alpha'] += alpha * mat.alpha_factor * mat_mask
else:
if mat.alpha_mode == AlphaMode.MASK:
imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
imgs['alpha'] += mat.alpha_factor * mat_mask
img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0)
else:
img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0]
if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
out_dict[type] = img
else:
z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32)
for i in range(0, faces.shape[0], chunk_size):
faces_chunk = faces[i:i+chunk_size]
rast, rast_db = dr.rasterize(
self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa)
)
z_filter = torch.logical_and(
rast[..., 3] != 0,
rast[..., 2] < z_buffer
)
z_buffer[z_filter] = rast[z_filter][..., 2]
for type in return_types:
img = None
if type == "mask" :
img = (rast[..., -1:] > 0).float()
elif type == "depth":
img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0]
elif type == "normal" :
face_normal_chunk = face_normal[i:i+chunk_size]
img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0]
img = (img + 1) / 2
elif type == "coord":
img = dr.interpolate(vertices, rast, faces_chunk)[0]
elif type == "attr":
if isinstance(mesh, MeshWithVoxel):
if 'grid_sample_3d' not in globals():
from flex_gemm.ops.grid_sample import grid_sample_3d
mask = rast[..., -1:] > 0
xyz = dr.interpolate(vertices, rast, faces_chunk)[0]
xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
img = grid_sample_3d(
mesh.attrs,
torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
mesh.voxel_shape,
xyz,
mode='trilinear'
)
img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
elif isinstance(mesh, MeshWithPbrMaterial):
tri_id = rast[0, :, :, -1:]
mask = tri_id > 0
uv_coords = mesh.uv_coords.reshape(1, -1, 2)
texc, texd = dr.interpolate(
uv_coords,
rast,
torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
rast_db=rast_db,
diff_attrs='all'
)
# Fix problematic texture coordinates
texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
texc = torch.clamp(texc, min=-1e3, max=1e3)
texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
texd = torch.clamp(texd, min=-1e3, max=1e3)
mid = mesh.material_ids[(tri_id - 1).long()]
imgs = {
'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device),
'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
}
for id, mat in enumerate(mesh.materials):
mat_mask = (mid == id).float() * mask.float()
mat_texc = texc * mat_mask
mat_texd = texd * mat_mask
if mat.base_color_texture is not None:
base_color = dr.texture(
mat.base_color_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['base_color'] += base_color * mat.base_color_factor * mat_mask
else:
imgs['base_color'] += mat.base_color_factor * mat_mask
if mat.metallic_texture is not None:
metallic = dr.texture(
mat.metallic_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['metallic'] += metallic * mat.metallic_factor * mat_mask
else:
imgs['metallic'] += mat.metallic_factor * mat_mask
if mat.roughness_texture is not None:
roughness = dr.texture(
mat.roughness_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
imgs['roughness'] += roughness * mat.roughness_factor * mat_mask
else:
imgs['roughness'] += mat.roughness_factor * mat_mask
if mat.alpha_mode == AlphaMode.OPAQUE:
imgs['alpha'] += 1.0 * mat_mask
else:
if mat.alpha_texture is not None:
alpha = dr.texture(
mat.alpha_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
if mat.alpha_mode == AlphaMode.MASK:
imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
imgs['alpha'] += alpha * mat.alpha_factor * mat_mask
else:
if mat.alpha_mode == AlphaMode.MASK:
imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
imgs['alpha'] += mat.alpha_factor * mat_mask
img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0)
else:
img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0]
if type not in out_dict:
out_dict[type] = img
else:
out_dict[type][z_filter] = img[z_filter]
for type in return_types:
img = out_dict[type]
if ssaa > 1:
img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
img = img.squeeze()
else:
img = img.permute(0, 3, 1, 2).squeeze()
out_dict[type] = img
if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types:
for k, s in mesh.layout.items():
out_dict[k] = out_dict['attr'][s]
del out_dict['attr']
return out_dict
from typing import *
import torch
from easydict import EasyDict as edict
import numpy as np
import utils3d
from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode
import torch.nn.functional as F
from ..utils.pipeline_logger import get_logger, log_mesh, log_uv, log_tensor, elapsed, section
from ..modules.sparse.linear import ROCM_SAFE_CHUNK
def _safe_transform4x4(vertices_homo: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
"""
Chunked drop-in for torch.bmm(vertices_homo, matrix) to work around the
ROCm GEMM bug where N > ~800k produces corrupt results.
vertices_homo: [B, N, 4] matrix: [B, 4, 4]
"""
B, N, _ = vertices_homo.shape
if N <= ROCM_SAFE_CHUNK:
return torch.bmm(vertices_homo, matrix)
parts = []
for s in range(0, N, ROCM_SAFE_CHUNK):
e = min(s + ROCM_SAFE_CHUNK, N)
parts.append(torch.bmm(vertices_homo[:, s:e, :], matrix))
return torch.cat(parts, dim=1)
def cube_to_dir(s, x, y):
if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y
elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y
elif s == 2: rx, ry, rz = x, y, torch.ones_like(x)
elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x)
elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y
elif s == 5: rx, ry, rz = -x, -torch.ones_like(x), -y
return torch.stack((rx, ry, rz), dim=-1)
def latlong_to_cubemap(latlong_map, res):
if 'dr' not in globals():
import nvdiffrast.torch as dr
cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')
for s in range(6):
gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
indexing='ij')
v = F.normalize(cube_to_dir(s, gx, gy), dim=-1)
tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5
tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi
texcoord = torch.cat((tu, tv), dim=-1)
cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]
return cubemap
class EnvMap:
def __init__(self, image: torch.Tensor):
self.image = image
@property
def _backend(self):
if not hasattr(self, '_nvdiffrec_envlight'):
if 'EnvironmentLight' not in globals():
from nvdiffrec_render.light import EnvironmentLight
cubemap = latlong_to_cubemap(self.image, [512, 512])
self._nvdiffrec_envlight = EnvironmentLight(cubemap)
self._nvdiffrec_envlight.build_mips()
return self._nvdiffrec_envlight
def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular)
def sample(self, directions: torch.Tensor):
if 'dr' not in globals():
import nvdiffrast.torch as dr
return dr.texture(
self._backend.base.unsqueeze(0),
directions.unsqueeze(0),
boundary_mode='cube',
)[0]
def intrinsics_to_projection(
intrinsics: torch.Tensor,
near: float,
far: float,
) -> torch.Tensor:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
ret[0, 0] = 2 * fx
ret[1, 1] = 2 * fy
ret[0, 2] = 2 * cx - 1
ret[1, 2] = - 2 * cy + 1
ret[2, 2] = (far + near) / (far - near)
ret[2, 3] = 2 * near * far / (near - far)
ret[3, 2] = 1.
return ret
def screen_space_ambient_occlusion(
depth: torch.Tensor,
normal: torch.Tensor,
perspective: torch.Tensor,
radius: float = 0.1,
bias: float = 1e-6,
samples: int = 64,
intensity: float = 1.0,
) -> torch.Tensor:
"""
Screen space ambient occlusion (SSAO)
Args:
depth (torch.Tensor): [H, W, 1] depth image
normal (torch.Tensor): [H, W, 3] normal image
perspective (torch.Tensor): [4, 4] camera projection matrix
radius (float): radius of the SSAO kernel
bias (float): bias to avoid self-occlusion
samples (int): number of samples to use for the SSAO kernel
intensity (float): intensity of the SSAO effect
Returns:
(torch.Tensor): [H, W, 1] SSAO image
"""
device = depth.device
H, W, _ = depth.shape
fx = perspective[0, 0]
fy = perspective[1, 1]
cx = perspective[0, 2]
cy = perspective[1, 2]
y_grid, x_grid = torch.meshgrid(
(torch.arange(H, device=device) + 0.5) / H * 2 - 1,
(torch.arange(W, device=device) + 0.5) / W * 2 - 1,
indexing='ij'
)
x_view = (x_grid.float() - cx) * depth[..., 0] / fx
y_view = (y_grid.float() - cy) * depth[..., 0] / fy
view_pos = torch.stack([x_view, y_view, depth[..., 0]], dim=-1) # [H, W, 3]
depth_feat = depth.permute(2, 0, 1).unsqueeze(0)
occlusion = torch.zeros((H, W), device=device)
# start sampling
for _ in range(samples):
# sample normal distribution, if inside, flip the sign
rnd_vec = torch.randn(H, W, 3, device=device)
rnd_vec = F.normalize(rnd_vec, p=2, dim=-1)
dot_val = torch.sum(rnd_vec * normal, dim=-1, keepdim=True)
sample_dir = torch.sign(dot_val) * rnd_vec
scale = torch.rand(H, W, 1, device=device)
scale = scale * scale
sample_pos = view_pos + sample_dir * radius * scale
sample_z = sample_pos[..., 2]
# project to screen space
z_safe = torch.clamp(sample_pos[..., 2], min=1e-5)
proj_u = (sample_pos[..., 0] * fx / z_safe) + cx
proj_v = (sample_pos[..., 1] * fy / z_safe) + cy
grid = torch.stack([proj_u, proj_v], dim=-1).unsqueeze(0)
geo_z = F.grid_sample(depth_feat, grid, mode='nearest', padding_mode='border').squeeze()
range_check = torch.abs(geo_z - sample_z) < radius
is_occluded = (geo_z <= sample_z - bias) & range_check
occlusion += is_occluded.float()
f_occ = occlusion / samples * intensity
f_occ = torch.clamp(f_occ, 0.0, 1.0)
return f_occ.unsqueeze(-1)
def aces_tonemapping(x: torch.Tensor) -> torch.Tensor:
"""
Applies ACES tone mapping curve to an HDR image tensor.
Input: x - HDR tensor, shape (..., 3), range [0, +inf)
Output: LDR tensor, same shape, range [0, 1]
"""
a = 2.51
b = 0.03
c = 2.43
d = 0.59
e = 0.14
# Apply the ACES fitted curve
mapped = (x * (a * x + b)) / (x * (c * x + d) + e)
# Clamp to [0, 1] for display or saving
return torch.clamp(mapped, 0.0, 1.0)
def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
"""
Applies gamma correction to an HDR image tensor.
"""
return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0)
class PbrMeshRenderer:
"""
Renderer for the PBR mesh.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options={}, device='cuda'):
if 'dr' not in globals():
import nvdiffrast.torch as dr
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"peel_layers": 8,
})
self.rendering_options.update(rendering_options)
self.glctx = dr.RasterizeCudaContext(device=device)
self.device=device
def render(
self,
mesh : Mesh,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
envmap : Union[EnvMap, Dict[str, EnvMap]],
use_envmap_bg : bool = False,
transformation : Optional[torch.Tensor] = None
) -> edict:
"""
Render the mesh.
Args:
mesh : meshmodel
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps
use_envmap_bg (bool): whether to use envmap as background
transformation (torch.Tensor): (4, 4) transformation matrix
Returns:
edict based on return_types containing:
shaded (torch.Tensor): [3, H, W] shaded color image
normal (torch.Tensor): [3, H, W] normal image
base_color (torch.Tensor): [3, H, W] base color image
metallic (torch.Tensor): [H, W] metallic image
roughness (torch.Tensor): [H, W] roughness image
"""
if 'dr' not in globals():
import nvdiffrast.torch as dr
if not isinstance(envmap, dict):
envmap = {'' : envmap}
num_envmaps = len(envmap)
resolution = self.rendering_options["resolution"]
near = self.rendering_options["near"]
far = self.rendering_options["far"]
ssaa = self.rendering_options["ssaa"]
if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
out_dict = edict(
normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
)
for i, k in enumerate(envmap.keys()):
shaded_key = f"shaded_{k}" if k != '' else "shaded"
out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
return out_dict
rays_o, rays_d = utils3d.torch.get_image_rays(
extrinsics, intrinsics, resolution * ssaa, resolution * ssaa
)
perspective = intrinsics_to_projection(intrinsics, near, far)
full_proj = (perspective @ extrinsics).unsqueeze(0)
extrinsics = extrinsics.unsqueeze(0)
L = get_logger()
section(f"PbrMeshRenderer.render res={resolution} ssaa={ssaa}")
vertices = mesh.vertices.unsqueeze(0)
vertices_orig = vertices.clone()
vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
if transformation is not None:
vertices_homo = _safe_transform4x4(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2))
vertices = vertices_homo[..., :3].contiguous()
vertices_camera = _safe_transform4x4(vertices_homo, extrinsics.transpose(-1, -2))
vertices_clip = _safe_transform4x4(vertices_homo, full_proj.transpose(-1, -2))
faces = mesh.faces
# ── Pre-rasterize sanity checks ──────────────────────────────────────
log_mesh(mesh.vertices, mesh.faces, "renderer-input")
L.info(f" {elapsed()} full_proj:\n{full_proj[0].cpu().numpy()}")
vc = vertices_clip[0] # [N, 4]
has_nan = torch.isnan(vc).any().item()
has_inf = torch.isinf(vc).any().item()
w_min, w_max = vc[:, 3].min().item(), vc[:, 3].max().item()
w_zero = (vc[:, 3].abs() < 1e-6).sum().item()
L.info(f" {elapsed()} vertices_clip: shape={list(vc.shape)} "
f"NaN={has_nan} inf={has_inf} "
f"x=[{vc[:,0].min().item():.4g},{vc[:,0].max().item():.4g}] "
f"y=[{vc[:,1].min().item():.4g},{vc[:,1].max().item():.4g}] "
f"z=[{vc[:,2].min().item():.4g},{vc[:,2].max().item():.4g}] "
f"w=[{w_min:.4g},{w_max:.4g}] w_zeros={w_zero}")
if has_nan or has_inf:
L.error(" ⚠ vertices_clip has NaN/inf — rasterizer will produce garbage!")
if w_min < 0:
L.warning(f" ⚠ vertices_clip has negative w values ({(vc[:,3]<0).sum().item()} vertices)"
" — behind camera, may cause artifacts")
# NDC coords after perspective divide
ndc = vc[:, :3] / vc[:, 3:4].clamp(min=1e-6)
L.info(f" {elapsed()} NDC (after w-divide): "
f"x=[{ndc[:,0].min().item():.4g},{ndc[:,0].max().item():.4g}] "
f"y=[{ndc[:,1].min().item():.4g},{ndc[:,1].max().item():.4g}] "
f"z=[{ndc[:,2].min().item():.4g},{ndc[:,2].max().item():.4g}] "
f"out_of_frustum={(ndc.abs() > 1.0).any(dim=1).sum().item()}/{vc.shape[0]}")
# Normal computation is skipped — all GPU and CPU smooth-normal approaches
# produce artifacts on ROCm GFX1201 for large meshes.
# A constant normal is used instead: normal view will be flat, but PBR/clay
# renders will be artifact-free.
_faces_cpu = mesh.faces.long().cpu() # [F, 3] — needed in the render loop
out_dict = edict()
shaded = torch.zeros((num_envmaps, resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
depth = torch.full((resolution * ssaa, resolution * ssaa, 1), 1e10, dtype=torch.float32, device=self.device)
normal = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
max_w = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
rast_test, _ = dr.rasterize(self.glctx, vertices_clip, faces, resolution=[resolution * ssaa, resolution * ssaa])
max_tri_id = rast_test[..., -1].max().item()
visible_px = (rast_test[..., -1] > 0).sum().item()
total_px = (resolution * ssaa) ** 2
L.info(f" {elapsed()} rasterize test: max_tri_id={max_tri_id:.0f} "
f"visible_px={visible_px}/{total_px} ({100.*visible_px/total_px:.1f}%)")
if max_tri_id > mesh.faces.shape[0]:
L.error(f" ⚠ max_tri_id {max_tri_id} > num_faces {mesh.faces.shape[0]} — CORRUPT RASTERIZE OUTPUT")
with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler:
for _ in range(self.rendering_options["peel_layers"]):
rast, rast_db = peeler.rasterize_next_layer()
if _ in [0, 1, 2]:
visible_pixels = (rast[..., -1] > 0).sum().item()
L.info(f" {elapsed()} DepthPeel layer={_} visible_px={visible_pixels}")
# Pos
pos = dr.interpolate(vertices, rast, faces)[0][0]
# Depth
gb_depth = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0][0]
# Constant normal pointing toward the camera (-Z in camera space).
# Smooth normal computation is unreliable on ROCm GFX1201 large meshes.
H = rast.shape[1]; W = rast.shape[2]
gb_normal = torch.zeros(H, W, 3, dtype=torch.float32, device=self.device)
gb_normal[..., 2] = -1.0
gb_normal = gb_normal * (rast[0, ..., 3:4] > 0).float()
gb_cam_normal = (extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1)).squeeze(-1)
if _ == 0:
out_dict.normal = -gb_cam_normal * 0.5 + 0.5
mask = (rast[0, ..., -1:] > 0).float()
out_dict.mask = mask
# PBR attributes
if isinstance(mesh, MeshWithVoxel):
if 'grid_sample_3d' not in globals():
from flex_gemm.ops.grid_sample import grid_sample_3d
mask = rast[..., -1:] > 0
xyz = dr.interpolate(vertices_orig, rast, faces)[0]
xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
img = grid_sample_3d(
mesh.attrs,
torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
mesh.voxel_shape,
xyz,
mode='trilinear'
)
img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
gb_basecolor = img[0, ..., mesh.layout['base_color']]
gb_metallic = img[0, ..., mesh.layout['metallic']]
gb_roughness = img[0, ..., mesh.layout['roughness']]
gb_alpha = img[0, ..., mesh.layout['alpha']]
elif isinstance(mesh, MeshWithPbrMaterial):
tri_id = rast[0, :, :, -1:]
mask = tri_id > 0
if _ == 0: # log once per render call
L.info(f" {elapsed()} MeshWithPbrMaterial: "
f"uv_coords={list(mesh.uv_coords.shape)} "
f"material_ids={list(mesh.material_ids.shape)} "
f"num_materials={len(mesh.materials)}")
log_uv(mesh.uv_coords.reshape(-1, 2), "mesh.uv_coords")
fi_min = mesh.material_ids.min().item()
fi_max = mesh.material_ids.max().item()
L.info(f" {elapsed()} material_ids range=[{fi_min},{fi_max}] "
f"num_materials={len(mesh.materials)}")
if fi_max >= len(mesh.materials):
L.error(f" ⚠ material_ids max {fi_max} >= num_materials {len(mesh.materials)}!")
uv_coords = mesh.uv_coords.reshape(1, -1, 2)
texc, texd = dr.interpolate(
uv_coords,
rast,
torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
rast_db=rast_db,
diff_attrs='all'
)
if _ == 0:
log_tensor(texc, "texc-pre-clamp")
# Fix problematic texture coordinates
texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
texc = torch.clamp(texc, min=-1e3, max=1e3)
texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
texd = torch.clamp(texd, min=-1e3, max=1e3)
if _ == 0:
log_tensor(texc, "texc-post-clamp")
mid = mesh.material_ids[(tri_id - 1).long()]
gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
for id, mat in enumerate(mesh.materials):
mat_mask = (mid == id).float() * mask.float()
mat_texc = texc * mat_mask
mat_texd = texd * mat_mask
if mat.base_color_texture is not None:
bc = dr.texture(
mat.base_color_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
gb_basecolor += bc * mat.base_color_factor * mat_mask
else:
gb_basecolor += mat.base_color_factor * mat_mask
if mat.metallic_texture is not None:
m = dr.texture(
mat.metallic_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
gb_metallic += m * mat.metallic_factor * mat_mask
else:
gb_metallic += mat.metallic_factor * mat_mask
if mat.roughness_texture is not None:
r = dr.texture(
mat.roughness_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
gb_roughness += r * mat.roughness_factor * mat_mask
else:
gb_roughness += mat.roughness_factor * mat_mask
if mat.alpha_mode == AlphaMode.OPAQUE:
gb_alpha += 1.0 * mat_mask
else:
if mat.alpha_texture is not None:
a = dr.texture(
mat.alpha_texture.image.unsqueeze(0),
mat_texc,
mat_texd,
filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
)[0]
if mat.alpha_mode == AlphaMode.MASK:
gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
gb_alpha += a * mat.alpha_factor * mat_mask
else:
if mat.alpha_mode == AlphaMode.MASK:
gb_alpha += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
elif mat.alpha_mode == AlphaMode.BLEND:
gb_alpha += mat.alpha_factor * mat_mask
if _ == 0:
out_dict.base_color = gb_basecolor
out_dict.metallic = gb_metallic
out_dict.roughness = gb_roughness
out_dict.alpha = gb_alpha
# Shading
gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2
gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0)
gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0)
gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0)
gb_orm = torch.cat([
torch.zeros_like(gb_metallic),
gb_roughness,
gb_metallic,
], dim=-1)
_log = get_logger()
_log.debug(f"--- RASTERIZATION DEBUG --- pos sum: {pos.sum().item()} | max: {pos.max().item()}")
_log.debug(f"gb_normal sum: {gb_normal.sum().item()} | gb_basecolor sum: {gb_basecolor.sum().item()} | gb_orm sum: {gb_orm.sum().item()} | mask sum: {mask.float().sum().item()}")
gb_shaded = torch.stack([
e.shade(
pos.unsqueeze(0),
gb_normal.unsqueeze(0),
gb_basecolor.unsqueeze(0),
gb_orm.unsqueeze(0),
rays_o,
specular=True,
)[0]
for e in envmap.values()
], dim=0)
# Compositing
w = (1 - alpha) * gb_alpha
depth = torch.where(w > max_w, gb_depth, depth)
normal = torch.where(w > max_w, gb_cam_normal, normal)
max_w = torch.maximum(max_w, w)
shaded += w * gb_shaded
alpha += w
# Ambient occulusion
f_occ = screen_space_ambient_occlusion(
depth, normal, perspective, intensity=1.5
)
shaded *= (1 - f_occ)
out_dict.clay = (1 - f_occ)
# Background
if use_envmap_bg:
bg = torch.stack([e.sample(rays_d) for e in envmap.values()], dim=0)
shaded += (1 - alpha) * bg
for i, k in enumerate(envmap.keys()):
shaded_key = f"shaded_{k}" if k != '' else "shaded"
out_dict[shaded_key] = shaded[i]
# SSAA
for k in out_dict.keys():
if ssaa > 1:
out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
else:
out_dict[k] = out_dict[k].permute(2, 0, 1)
out_dict[k] = out_dict[k].squeeze()
# Post processing
for k in envmap.keys():
shaded_key = f"shaded_{k}" if k != '' else "shaded"
out_dict[shaded_key] = aces_tonemapping(out_dict[shaded_key])
out_dict[shaded_key] = gamma_correction(out_dict[shaded_key])
return out_dict
import torch
from easydict import EasyDict as edict
from ..representations import Voxel
from easydict import EasyDict as edict
class VoxelRenderer:
"""
Renderer for the Voxel representation.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options={}) -> None:
self.rendering_options = edict({
"resolution": None,
"near": 0.1,
"far": 10.0,
"ssaa": 1,
})
self.rendering_options.update(rendering_options)
def render(
self,
voxel: Voxel,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
colors_overwrite: torch.Tensor = None
) -> edict:
"""
Render the gausssian.
Args:
voxel (Voxel): Voxel representation.
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
colors_overwrite (torch.Tensor): (N, 3) override color
Returns:
edict containing:
color (torch.Tensor): (3, H, W) rendered color image
depth (torch.Tensor): (H, W) rendered depth
alpha (torch.Tensor): (H, W) rendered alpha
...
"""
# lazy import
if 'o_voxel' not in globals():
import o_voxel
renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options)
positions = voxel.position
attrs = voxel.attrs if colors_overwrite is None else colors_overwrite
voxel_size = voxel.voxel_size
# Render
render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics)
ret = {
'depth': render_ret['depth'],
'alpha': render_ret['alpha'],
}
if colors_overwrite is not None:
ret['color'] = render_ret['attr']
else:
for k, s in voxel.layout.items():
ret[k] = render_ret['attr'][s]
return ret
import importlib
__attributes = {
'Mesh': 'mesh',
'Voxel': 'voxel',
'MeshWithVoxel': 'mesh',
'MeshWithPbrMaterial': 'mesh',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial
from .voxel import Voxel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment