Commit f05e915f authored by weishb

Initial commit

parent 297bf637
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
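# TRELLIS.2
## Paper
[Native and Compact Structured Latents for 3D Generation](https://arxiv.org/abs/2512.14692?utm_source=chatgpt.com)
## Model Overview
TRELLIS.2 is a state-of-the-art large 3D generative model (4 billion parameters) designed for high-fidelity image-to-3D generation. It uses a novel "field-free" sparse voxel structure called O-Voxel to reconstruct and generate arbitrary 3D assets with complex topologies, sharp features, and full PBR materials.
## Environment Requirements
| Software | Version |
| :------: | :------: |
| DTK | 26.04 |
| Python | 3.10 |
| Transformers | 4.56.0 |

Recommended image:
harbor.sourcefind.cn:5443/dcu/admin/base/custom:ubuntu22.04-dtk26.04-py3.10-20260526-trellis2
- Adjust the `-v` mount paths to match where your model and data actually live.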
```bash
docker run -it \
--shm-size 200g \
--network=host \
--name TRELLIS2 \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--device=/dev/mkfd \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
-v /opt/hyhal/:/opt/hyhal/:ro \
-v /path/your_code_data/:/path/your_code_data/ \
harbor.sourcefind.cn:5443/dcu/admin/base/custom:ubuntu22.04-dtk26.04-py3.10-20260526-trellis2 bash
```
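More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The DCU-specific deep learning libraries required by this project can be downloaded and installed from the [光合](https://developer.sourcefind.cn/tool/) developer community.
Install other packages: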
```bash
git clone https://developer.sourcefind.cn/codes/modelzoo/trellis.2.git
```
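## Pretrained Weights
**Choose the model to download according to the `Supported DCU Models` column. FP8 models are supported only on BW1100/BW1101; do not use them on other models!**
| Model | Weight Size | Data Type | Supported DCU Models | Minimum Number of Cards | Download |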
|:-----:|:----------:|:----------:|:----------:|:---------------------:|:----------:|
| TRELLIS.2-4B | 4B | BF16 | BW1000 | 1 | [HuggingFace](https://huggingface.co/microsoft/TRELLIS.2-4B) |
| TRELLIS-image-large | 1B | BF16 | BW1000 | 1 | [HuggingFace](https://huggingface.co/microsoft/TRELLIS-image-large?utm_source=chatgpt.com) |
| dinov3-vitl16-pretrain-lvd1689m | 0.5B | BF16 | BW1000 | 1 | [HuggingFace](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m?utm_source=chatgpt.com) |
| RMBG-2.0 | 0.5B | BF16 | BW1000 | 1 | [HuggingFace](https://huggingface.co/briaai/RMBG-2.0?utm_source=chatgpt.com) |
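## Dataset
None yet.
## Training
None yet.
## Inference
### Torch
#### Single-Node Inference
**To use offline models, set the environment variables below. If your network allows pulling models online, you can skip them.**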
```bash
export HF_HOME=/path/to/hf_cache
export HUGGINGFACE_HUB_CACHE=$HF_HOME/hub
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
cd TRELLIS.2
python app.py
```
## Results
<div align=center>
<img src="./doc/01.png"/>
</div>
### Accuracy
Accuracy on DCU matches that on GPU. Inference framework: PyTorch
## Source Repository and Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/trellis.2
## References
- https://github.com/microsoft/TRELLIS.2
- https://microsoft.github.io/TRELLIS.2/?utm_source=chatgpt.com
- ......
Additional notes:
For model.properties (required), LICENSE (required), CONTRIBUTORS, the model icon (required), and other information, refer to [`ModelZooStd.md`](./ModelZooStd.md).
Each model must keep the original project's README.md, renamed to README_origin.md.
![](assets/teaser.webp)
# Native and Compact Structured Latents for 3D Generation
<a href="https://arxiv.org/abs/2512.14692"><img src="https://img.shields.io/badge/Paper-Arxiv-b31b1b.svg" alt="Paper"></a>
<a href="https://huggingface.co/microsoft/TRELLIS.2-4B"><img src="https://img.shields.io/badge/Hugging%20Face-Model-yellow" alt="Hugging Face"></a>
<a href="https://huggingface.co/spaces/microsoft/TRELLIS.2"><img src="https://img.shields.io/badge/Hugging%20Face-Demo-blueviolet"></a>
<a href="https://microsoft.github.io/TRELLIS.2"><img src="https://img.shields.io/badge/Project-Website-blue" alt="Project Page"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green" alt="License"></a>
https://github.com/user-attachments/assets/63b43a7e-acc7-4c81-a900-6da450527d8f
*(Compressed version due to GitHub size limits. See the full-quality video on our project page!)*
**TRELLIS.2** is a state-of-the-art large 3D generative model (4B parameters) designed for high-fidelity **image-to-3D** generation. It leverages a novel "field-free" sparse voxel structure termed **O-Voxel** to reconstruct and generate arbitrary 3D assets with complex topologies, sharp features, and full PBR materials.
## ✨ Features
### 1. High Quality, Resolution & Efficiency
Our 4B-parameter model generates high-resolution fully textured assets with exceptional fidelity and efficiency using vanilla DiTs. It utilizes a Sparse 3D VAE with 16× spatial downsampling to encode assets into a compact latent space.
| Resolution | Total Time* | Breakdown (Shape + Mat) |
| :--- | :--- | :--- |
| **512³** | **~3s** | 2s + 1s |
| **1024³** | **~17s** | 10s + 7s |
| **1536³** | **~60s** | 35s + 25s |
<small>*Tested on NVIDIA H100 GPU.</small>
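
As a rough illustration of the 16× spatial downsampling mentioned above, the sketch below computes the latent grid edge length and the maximum number of latent cells for each output resolution in the table. It assumes the latent edge is simply the voxel resolution divided by 16; the helper `latent_grid` is hypothetical and not part of the TRELLIS.2 codebase.

```python
DOWNSAMPLE = 16  # spatial downsampling factor of the Sparse 3D VAE (from the text above)


def latent_grid(resolution: int, factor: int = DOWNSAMPLE) -> tuple[int, int]:
    """Return (edge length, dense cell count) of the latent grid.

    Hypothetical helper: assumes edge = resolution // factor.
    """
    if resolution % factor != 0:
        raise ValueError("resolution must be a multiple of the downsampling factor")
    edge = resolution // factor
    return edge, edge ** 3


for res in (512, 1024, 1536):
    edge, cells = latent_grid(res)
    print(f"{res}^3 voxels -> {edge}^3 latent grid (at most {cells:,} cells)")
```

Because the representation is sparse, only a fraction of these cells is expected to be occupied for a typical asset.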
### 2. Arbitrary Topology Handling
The **O-Voxel** representation breaks the limits of iso-surface fields. It robustly handles complex structures without lossy conversion:
* **Open Surfaces** (e.g., clothing, leaves)
* **Non-manifold Geometry**
* **Internal Enclosed Structures**
### 3. Rich Texture Modeling
Beyond basic colors, TRELLIS.2 models arbitrary surface attributes including **Base Color, Roughness, Metallic, and Opacity**, enabling photorealistic rendering and transparency support.
### 4. Minimalist Processing
Data processing is streamlined for instant conversions that are fully **rendering-free** and **optimization-free**.
* **< 10s** (Single CPU): Textured Mesh → O-Voxel
* **< 100ms** (CUDA): O-Voxel → Textured Mesh
## 🗺️ Roadmap
- [x] Paper release
- [x] Release image-to-3D inference code
- [x] Release pretrained checkpoints (4B)
- [x] Hugging Face Spaces demo
- [x] Release shape-conditioned texture generation inference code
- [x] Release training code
## 🛠️ Installation
### Prerequisites
- **System**: Linux only.
- **Hardware**: An NVIDIA GPU (verified on A100/H100, 24GB+ recommended) or AMD GPU (verified on RX 9070 XT 16GB under ROCm).
- **Software**:
  - **CUDA**: [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) 12.4 recommended.
  - **ROCm**: [ROCm](https://rocm.docs.amd.com/en/latest/) 7.2 recommended.
  - Python 3.10 or higher required.
### Installation Steps
1. Clone the repo:
```sh
git clone -b rocm https://github.com/Cardboard-box-a/TRELLIS.2_rocm.git --recursive
cd TRELLIS.2_rocm
```
2. Install PyTorch into your environment **before** running `setup.sh`. Use the index URL matching your platform:
**CUDA:**
```sh
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
```
**ROCm:**
```sh
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm7.2
```
3. Install the dependencies:
```sh
. ./setup.sh --basic --flash-attn --nvdiffrast --nvdiffrec --cumesh --o-voxel --flexgemm
```
Notes:
- `setup.sh` auto-detects CUDA vs ROCm and installs the appropriate variants.
- All packages including nvdiffrast and nvdiffrec work on both CUDA and ROCm.
- The installation may take a while — flash-attention builds from source on ROCm. Install flags one at a time if you hit issues.
- Run `. ./setup.sh --help` for the full list of flags.
## AMD ROCm Support
This branch has been tested on an **AMD RX 9070 XT 16GB** (gfx1201) under ROCm. The setup script auto-detects CUDA vs ROCm and installs the appropriate dependencies.
### Installation (ROCm)
First install ROCm PyTorch, then run setup:
```sh
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm7.2
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
. ./setup.sh --basic --flash-attn --cumesh --o-voxel --flexgemm --nvdiffrast --nvdiffrec
```
### Running
Flash Attention on ROCm requires the Triton backend. Export this before running:
```sh
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python app.py
```
If you prefer not to use Flash Attention, SDPA (Scaled Dot-Product Attention) is also supported. Set the attention backend before running:
```sh
export ATTN_BACKEND="sdpa"
python app.py
```
### AMD GPU Architecture
The `--flash-attn` step in `setup.sh` compiles for `gfx1201` (RX 9070 / RX 9070 XT) by default. If you have a different AMD GPU, edit the `GPU_ARCHS` line in `setup.sh` before running. Check your GPU's gfx architecture with `rocminfo | grep gfx`.
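
If you would rather script the architecture check than read the `rocminfo` output by eye, something along these lines may help. It is only a sketch: the function names are made up for this example, and it assumes `rocminfo` prints architecture names in the usual `gfx...` form.

```python
import re
import subprocess


def parse_gfx_archs(rocminfo_output: str) -> list[str]:
    """Extract unique gfx architecture names (e.g. 'gfx1201') in order of appearance."""
    seen = []
    for name in re.findall(r"\bgfx[0-9a-f]+\b", rocminfo_output):
        if name not in seen:
            seen.append(name)
    return seen


def detect_gfx_archs() -> list[str]:
    """Run rocminfo and parse its output; returns [] if rocminfo cannot be run."""
    try:
        result = subprocess.run(["rocminfo"], capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return []
    return parse_gfx_archs(result.stdout)


if __name__ == "__main__":
    print(detect_gfx_archs())
```

The resulting name is what you would compare against the `GPU_ARCHS` line in `setup.sh`.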
## 📦 Pretrained Weights
The pretrained model **TRELLIS.2-4B** is available on Hugging Face. Please refer to the model card there for more details.
| Model | Parameters | Resolution | Link |
| :--- | :--- | :--- | :--- |
| **TRELLIS.2-4B** | 4 Billion | 512³ - 1536³ | [Hugging Face](https://huggingface.co/microsoft/TRELLIS.2-4B) |
## 🚀 Usage
### 1. Image to 3D Generation
#### Minimal Example
Here is an [example](example.py) of how to use the pretrained models for 3D asset generation.
```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory
import cv2
import imageio
from PIL import Image
import torch
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.utils import render_utils
from trellis2.renderers import EnvMap
import o_voxel
# 1. Setup Environment Map
envmap = EnvMap(torch.tensor(
    cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
    dtype=torch.float32, device='cuda'
))
# 2. Load Pipeline
pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
pipeline.cuda()
# 3. Load Image & Run
image = Image.open("assets/example_image/T.png")
mesh = pipeline.run(image)[0]
mesh.simplify(16777216) # nvdiffrast limit
# 4. Render Video
video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap))
imageio.mimsave("sample.mp4", video, fps=15)
# 5. Export to GLB
glb = o_voxel.postprocess.to_glb(
    vertices = mesh.vertices,
    faces = mesh.faces,
    attr_volume = mesh.attrs,
    coords = mesh.coords,
    attr_layout = mesh.layout,
    voxel_size = mesh.voxel_size,
    aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
    decimation_target = 1000000,
    texture_size = 4096,
    remesh = True,
    remesh_band = 1,
    remesh_project = 0,
    verbose = True
)
glb.export("sample.glb", extension_webp=True)
```
Upon execution, the script generates the following files:
- `sample.mp4`: A video visualizing the generated 3D asset with PBR materials and environmental lighting.
- `sample.glb`: The extracted PBR-ready 3D asset in GLB format.
**Note:** The `.glb` file is exported in `OPAQUE` mode by default. Although the alpha channel is preserved within the texture map, it is not active initially. To enable transparency, import the asset into your 3D software and manually connect the texture's alpha channel to the material's opacity or alpha input.
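
In glTF 2.0, the per-material `alphaMode` property (`OPAQUE`, `MASK`, or `BLEND`) controls whether alpha is used. The sketch below shows, on a toy glTF JSON dictionary, how that field could be switched to `BLEND`. It is an illustration only: `enable_alpha_blending` is a hypothetical helper, it does not read or write `.glb` files itself, and it assumes the alpha you want lives in the base color texture.

```python
import copy


def enable_alpha_blending(gltf: dict, mode: str = "BLEND") -> dict:
    """Return a copy of a glTF JSON document with alphaMode set on every material."""
    if mode not in ("OPAQUE", "MASK", "BLEND"):
        raise ValueError(f"unsupported alphaMode: {mode}")
    patched = copy.deepcopy(gltf)  # leave the caller's document untouched
    for material in patched.get("materials", []):
        material["alphaMode"] = mode
    return patched


# Toy document standing in for the JSON chunk of an exported .glb file.
doc = {"materials": [{"name": "mat0", "alphaMode": "OPAQUE"}]}
patched = enable_alpha_blending(doc)
print(patched["materials"][0]["alphaMode"])  # BLEND
```

Whether a given viewer then renders the asset as transparent depends on how it interprets the material, so treat this as a starting point rather than a guaranteed fix.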
#### Web Demo
[app.py](app.py) provides a simple web demo for image-to-3D asset generation. You can run the demo with the following command:
```sh
python app.py
```
Then, you can access the demo at the address shown in the terminal.
### 2. PBR Texture Generation
See [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. You can also run [app_texturing.py](app_texturing.py) as a web demo for PBR texture generation.
## 🏋️ Training
We provide the full training codebase, enabling users to train **TRELLIS.2** from scratch or fine-tune it on custom datasets.
### 1. Data Preparation
Before training, raw 3D assets must be converted into the **O-Voxel** representation. This process includes mesh conversion, compact structured latent generation, and metadata preparation.
> 📂 **Please refer to [data_toolkit/README.md](data_toolkit/README.md) for detailed instructions on data preprocessing and dataset organization.**
### 2. Running Training
Training is managed through the `train.py` script, which accepts multiple command-line arguments to configure experiments:
* `--config`: Path to the experiment configuration file.
* `--output_dir`: Directory for training outputs.
* `--load_dir`: Directory to load checkpoints from (defaults to `output_dir`).
* `--ckpt`: Checkpoint step to resume from (defaults to the latest).
* `--data_dir`: Dataset path or a JSON string specifying dataset locations.
* `--auto_retry`: Number of automatic retries upon failure.
* `--tryrun`: Perform a dry run without actual training.
* `--profile`: Enable training profiling.
* `--num_nodes`: Number of nodes for distributed training.
* `--node_rank`: Rank of the current node.
* `--num_gpus`: Number of GPUs per node (defaults to all available GPUs).
* `--master_addr`: Master node address for distributed training.
* `--master_port`: Port for distributed training communication.
### SC-VAE Training
To train the shape SC-VAE, run:
```sh
python train.py \
--config configs/scvae/shape_vae_next_dc_f16c32_fp16.json \
--output_dir results/shape_vae_next_dc_f16c32_fp16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"mesh_dump\": \"datasets/ObjaverseXL_sketchfab/mesh_dumps\", \"dual_grid\": \"datasets/ObjaverseXL_sketchfab/dual_grid_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
This command trains the shape SC-VAE on the **Objaverse-XL** dataset using the `shape_vae_next_dc_f16c32_fp16.json` configuration. Training outputs will be saved to `results/shape_vae_next_dc_f16c32_fp16`.
The dataset is specified as a JSON string, where each dataset entry includes:
* `base`: Root directory of the dataset.
* `mesh_dump`: Directory containing preprocessed mesh dumps.
* `dual_grid`: Directory with precomputed dual-grid representations.
* `asset_stats`: Directory containing precomputed asset statistics.
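
Hand-escaping the JSON string for `--data_dir` is error-prone. One option is to build it with `json.dumps` and let `shlex.quote` handle shell quoting, as in the sketch below. The helper `build_train_command` is hypothetical; the paths and flags are the ones from the command above.

```python
import json
import shlex


def build_train_command(config: str, output_dir: str, datasets: dict) -> str:
    """Assemble a shell-safe train.py command line with a JSON --data_dir value."""
    args = [
        "python", "train.py",
        "--config", config,
        "--output_dir", output_dir,
        "--data_dir", json.dumps(datasets),
    ]
    return " ".join(shlex.quote(arg) for arg in args)


root = "datasets/ObjaverseXL_sketchfab"
datasets = {
    "ObjaverseXL_sketchfab": {
        "base": root,
        "mesh_dump": f"{root}/mesh_dumps",
        "dual_grid": f"{root}/dual_grid_256",
        "asset_stats": f"{root}/asset_stats",
    }
}
command = build_train_command(
    "configs/scvae/shape_vae_next_dc_f16c32_fp16.json",
    "results/shape_vae_next_dc_f16c32_fp16",
    datasets,
)
print(command)
```

The printed command can be pasted into a shell, and the same dictionary can be reused for the other training stages by swapping the keys.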
To fine-tune the model at a higher resolution, use the `shape_vae_next_dc_f16c32_fp16_ft_512.json` configuration. Remember to update the `finetune_ckpt` field and adjust the dataset paths accordingly.
To train the texture SC-VAE, run:
```sh
python train.py \
--config configs/scvae/tex_vae_next_dc_f16c32_fp16.json \
--output_dir results/tex_vae_next_dc_f16c32_fp16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"pbr_dump\": \"datasets/ObjaverseXL_sketchfab/pbr_dumps\", \"pbr_voxel\": \"datasets/ObjaverseXL_sketchfab/pbr_voxels_256\", \"asset_stats\": \"datasets/ObjaverseXL_sketchfab/asset_stats\"}}"
```
### Flow Model Training
To train the sparse structure flow model, run:
```sh
python train.py \
--config configs/gen/ss_flow_img_dit_1_3B_64_bf16.json \
--output_dir results/ss_flow_img_dit_1_3B_64_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"ss_latent\": \"datasets/ObjaverseXL_sketchfab/ss_latents/ss_enc_conv3d_16l8_fp16_64\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```
This command trains the sparse-structure flow model on the **Objaverse-XL** dataset using the specified configuration file. Outputs are saved to `results/ss_flow_img_dit_1_3B_64_bf16`.
The dataset configuration includes:
* `base`: Root dataset directory.
* `ss_latent`: Directory containing precomputed sparse-structure latents.
* `render_cond`: Directory containing conditional rendering images.
The second- and third-stage flow models for shape and texture generation can be trained using the following configurations:
* Shape flow: `slat_flow_img2shape_dit_1_3B_512_bf16.json`
* Texture flow: `slat_flow_imgshape2tex_dit_1_3B_512_bf16.json`
Example commands:
```sh
# Shape flow model
python train.py \
--config configs/gen/slat_flow_img2shape_dit_1_3B_512_bf16.json \
--output_dir results/slat_flow_img2shape_dit_1_3B_512_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
# Texture flow model
python train.py \
--config configs/gen/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json \
--output_dir results/slat_flow_imgshape2tex_dit_1_3B_512_bf16 \
--data_dir "{\"ObjaverseXL_sketchfab\": {\"base\": \"datasets/ObjaverseXL_sketchfab\", \"shape_latent\": \"datasets/ObjaverseXL_sketchfab/shape_latents/shape_enc_next_dc_f16c32_fp16_512\", \"pbr_latent\": \"datasets/ObjaverseXL_sketchfab/pbr_latents/tex_enc_next_dc_f16c32_fp16_512\", \"render_cond\": \"datasets/ObjaverseXL_sketchfab/renders_cond\"}}"
```
Higher-resolution fine-tuning can be performed by updating the `finetune_ckpt` field in the following configuration files and adjusting the dataset paths accordingly:
* `slat_flow_img2shape_dit_1_3B_512_bf16_ft1024.json`
* `slat_flow_imgshape2tex_dit_1_3B_512_bf16_ft1024.json`
## 🧩 Related Packages
TRELLIS.2 is built upon several specialized high-performance packages developed by our team:
* **[O-Voxel](o-voxel):**
Core library handling the logic for converting between textured meshes and the O-Voxel representation, ensuring instant bidirectional transformation.
* **[FlexGEMM](https://github.com/Cardboard-box-a/FlexGEMM-rocm):**
Efficient sparse convolution implementation based on Triton, enabling rapid processing of sparse voxel structures. This fork adds ROCm/HIP support (ieee precision fix for AMD Triton kernels).
* **[CuMesh](https://github.com/Cardboard-box-a/CuMesh):**
CUDA-accelerated mesh utilities used for high-speed post-processing, remeshing, decimation, and UV-unwrapping. This fork includes ROCm/HIP support.
* **[nvdiffrast-hip](https://github.com/Cardboard-box-a/nvdiffrast-hip):**
HIP/ROCm port of nvdiffrast for AMD GPUs.
## ⚖️ License
This model and code are released under the **[MIT License](LICENSE)**.
Please note that certain dependencies operate under separate license terms:
- [**nvdiffrast**](https://github.com/NVlabs/nvdiffrast): Utilized for rendering generated 3D assets. This package is governed by its own [License](https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt).
- [**nvdiffrec**](https://github.com/Cardboard-box-a/nvdiffrec): Implements the split-sum renderer for PBR materials. This fork adds ROCm/HIP support. Governed by its own [License](https://github.com/NVlabs/nvdiffrec/blob/main/LICENSE.txt).
## 📚 Citation
If you find this model useful for your research, please cite our work:
```bibtex
@article{
xiang2025trellis2,
title={Native and Compact Structured Latents for 3D Generation},
author={Xiang, Jianfeng and Chen, Xiaoxue and Xu, Sicheng and Wang, Ruicheng and Lv, Zelong and Deng, Yu and Zhu, Hongyuan and Dong, Yue and Zhao, Hao and Yuan, Nicholas Jing and Yang, Jiaolong},
journal={Tech report},
year={2025}
}
```
<!-- BEGIN MICROSOFT SECURITY.MD V1.0.0 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which
includes all source code repositories in our GitHub organizations.
**Please do not report security vulnerabilities through public GitHub issues.**
For security reporting information, locations, contact information, and policies,
please review the latest guidance for Microsoft repositories at
[https://aka.ms/SECURITY.md](https://aka.ms/SECURITY.md).
<!-- END MICROSOFT SECURITY.MD BLOCK -->
# Sparse Structure Visualization Guide
This guide explains how to use sparse structure visualization features added to TRELLIS.2.
## Overview
The sparse structure is a 3D voxel grid that represents which parts of the 3D space are occupied by the object being generated. Visualizing this helps you understand:
- The initial "skeleton" or blueprint of the 3D object
- How different pipeline types (512, 1024_cascade, 1536_cascade) affect the sparse structure
- The distribution and density of occupied voxels
- The upsampling process in cascade modes (from LR to HR coordinates)
- Potential issues in the generation process
## Stages of Visualization
### Stage 1: Initial Sparse Structure
Generated by [`sample_sparse_structure()`](trellis2/pipelines/trellis2_image_to_3d.py:189) - this is the initial coarse voxel grid.
### Stage 2: High-Resolution Coordinates (Cascade Modes Only)
Generated by [`sample_shape_slat_cascade()`](trellis2/pipelines/trellis2_image_to_3d.py:280) - these are the upsampled coordinates after the decoder upsamples the sparse latent 4x.
**Note:** HR coordinates visualization is only available for cascade pipeline types (`1024_cascade` and `1536_cascade`).
### Stage 3: Quantized Coordinates (Cascade Modes Only)
Generated after the resolution adjustment loop in [`sample_shape_slat_cascade()`](trellis2/pipelines/trellis2_image_to_3d.py:412) - these are the coordinates after quantization, deduplication, and adaptive resolution adjustment.
**What this shows:**
- The final coordinate grid used for shape generation
- How many tokens remain after adaptive resolution reduction
- The actual spatial resolution being used (may be less than target)
### Stage 4: Final SLat Features (Cascade Modes Only)
Generated after flow model sampling and denormalization in [`sample_shape_slat_cascade()`](trellis2/pipelines/trellis2_image_to_3d.py:450) - these are the learned features at each coordinate.
**What this shows:**
- The actual learned shape features
- Feature value distributions across the object
- Quality of the generated shape representation
**Note:** SLat features visualization is only available for cascade pipeline types.
### Stage 5: Texture Features (Cascade Modes Only)
Generated during texture sampling in [`sample_tex_slat()`](trellis2/pipelines/trellis2_image_to_3d.py:567) - these are the learned texture attributes at each coordinate.
**What this shows:**
- Learned texture features (e.g., RGB colors, roughness, metallic properties)
- How texture varies across spatial locations
- Feature value distributions for each texture channel
**Note:** Texture features typically have multiple dimensions (e.g., 3 for RGB textures).
## Understanding the Visualizations
### What You're Seeing
The sparse structure coordinates have shape `[N, 4]` where:
- **Column 0**: Batch index (always 0 for single samples)
- **Column 1**: X coordinate (0 to resolution-1)
- **Column 2**: Y coordinate (0 to resolution-1)
- **Column 3**: Z coordinate (0 to resolution-1)
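
To make the `[N, 4]` layout concrete, here is a small sketch that checks each row against the column meanings above and collects the unique occupied voxels. It works on plain Python tuples (for a tensor you could pass `coords.tolist()`); `unique_voxels` is a hypothetical helper, not a pipeline method.

```python
def unique_voxels(coords, resolution):
    """Validate (batch, x, y, z) rows and return the set of occupied (x, y, z) voxels."""
    voxels = set()
    for batch, x, y, z in coords:
        if batch != 0:
            raise ValueError("expected batch index 0 for a single sample")
        if not all(0 <= value < resolution for value in (x, y, z)):
            raise ValueError(f"coordinate {(x, y, z)} is outside 0..{resolution - 1}")
        voxels.add((x, y, z))
    return voxels


# Toy data: four rows, one of them a duplicate.
toy = [(0, 0, 0, 0), (0, 1, 2, 3), (0, 1, 2, 3), (0, 31, 31, 31)]
print(len(unique_voxels(toy, 32)))  # 3
```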
### Initial Sparse Structure vs. HR Coordinates
When using cascade modes, you'll see two sets of visualizations:
1. **Initial Sparse Structure** (e.g., `sparse_structure_1024_cascade_seed42_*.png`)
- Coarse 32³ voxel grid
- ~5,000 - 15,000 occupied voxels
- Generated directly from the sparse structure flow model
2. **HR Coordinates** (e.g., `hr_coords_1024_upsampled_*.png`)
- Upsampled coordinates (4x denser)
- ~20,000 - 60,000 coordinates
- Generated by the decoder upsampling the shape SLat
- Shows the refined structure before final shape generation
**Key Insight:** Comparing these two visualizations shows how the upsampling process refines the initial sparse structure into a more detailed representation.
## Available Visualization Methods
### 1. Matplotlib 3D Scatter Plot (`visualize_sparse_structure_matplotlib`)
Shows the sparse structure as a 3D scatter plot with color-coded Z coordinates.
**Best for:** Understanding the overall 3D shape and spatial distribution.
**Output:** Interactive 3D plot (or saved PNG file)
### 2. Voxel Grid Visualization (`visualize_sparse_structure_voxel`)
Displays the sparse structure as a 3D voxel grid where each occupied voxel is shown as a point.
**Best for:** Seeing the actual voxel structure and understanding resolution effects.
**Output:** 3D voxel visualization (or saved PNG file)
### 3. 2D Projections (`visualize_sparse_structure_projections`)
Shows three orthogonal 2D projections:
- **XY Projection**: Top view (looking down Z axis)
- **XZ Projection**: Side view (looking down Y axis)
- **YZ Projection**: Front view (looking down X axis)
**Best for:** Quick analysis of shape from different angles.
**Output:** Three 2D scatter plots in one figure (or saved PNG file)
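
Each projection amounts to dropping one spatial axis and de-duplicating the remaining pairs. The sketch below shows the idea on toy data; `project` is a hypothetical helper written for this guide, and the real visualization function may be implemented differently.

```python
def project(coords, drop_axis):
    """Project (batch, x, y, z) rows onto a plane by dropping one spatial axis.

    drop_axis is 'x', 'y', or 'z'. Returns the sorted unique 2D points.
    """
    dropped = {"x": 1, "y": 2, "z": 3}[drop_axis]
    first, second = [column for column in (1, 2, 3) if column != dropped]
    return sorted({(row[first], row[second]) for row in coords})


toy = [(0, 1, 2, 3), (0, 1, 2, 7), (0, 4, 2, 3)]
xy = project(toy, "z")  # looking down the Z axis
xz = project(toy, "y")  # looking down the Y axis
yz = project(toy, "x")  # looking down the X axis
print(xy, xz, yz)
```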
### 4. Multi-View Visualization (`visualize_sparse_structure_multi_view`)
Combines 3D scatter plot with 2D projections in a single figure.
**Best for:** Comprehensive overview of sparse structure.
**Output:** Combined 3D + 2D visualization (or saved PNG file)
### 5. Statistical Analysis (`analyze_sparse_structure`)
Prints numerical statistics about the sparse structure:
- Total number of occupied voxels
- Coordinate ranges (X, Y, Z)
- Center position
- Standard deviation
- Bounding box volume
**Best for:** Quick quantitative analysis without visualization.
**Output:** Console output with statistics
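
For reference, the listed statistics can be approximated with the standard library alone. This is a sketch under stated assumptions, not the actual `analyze_sparse_structure` implementation: in particular, the bounding box volume here counts voxels inclusively, which may differ from what the pipeline reports.

```python
import statistics


def sparse_stats(coords):
    """Summarize (batch, x, y, z) rows; returns a dict of plain Python values."""
    axes = {name: [row[i] for row in coords] for name, i in (("x", 1), ("y", 2), ("z", 3))}
    ranges = {name: (min(values), max(values)) for name, values in axes.items()}
    volume = 1
    for low, high in ranges.values():
        volume *= high - low + 1  # assumption: inclusive extent, counted in voxels
    return {
        "count": len(coords),
        "ranges": ranges,
        "center": {name: statistics.fmean(values) for name, values in axes.items()},
        "std": {name: statistics.pstdev(values) for name, values in axes.items()},
        "bbox_volume": volume,
    }


toy = [(0, 2, 5, 3), (0, 4, 5, 7), (0, 6, 8, 5)]
stats = sparse_stats(toy)
print(stats["count"], stats["ranges"], stats["bbox_volume"])
```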
### 6. SLat Features Visualization (`visualize_slat_features`)
Shows learned features in the shape Structured Latent (SLat) as a 3D scatter plot.
**Best for:** Understanding what the model has learned at each spatial location.
**Output:** 3D plot colored by feature values (or saved PNG file)
**Parameters:**
- `feature_idx`: Which feature dimension to visualize (default: 0)
- Multiple features can be visualized by calling with different indices
**Note:** Texture features have multiple dimensions (e.g., 3 for RGB), each representing learned texture attributes at each coordinate.
### 7. SLat Features Analysis (`analyze_slat_features`)
Prints numerical statistics about SLat features:
- Number of tokens (coordinates)
- Feature dimensions
- Statistics for each feature (min, max, mean, std)
- NaN/Inf value checks
- Coordinate ranges
**Best for:** Debugging feature values and checking for anomalies.
**Output:** Console output with feature statistics
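
A per-channel report of this kind could look roughly like the following. It operates on a plain list of rows so it runs without PyTorch; `feature_stats` is a hypothetical stand-in, and it computes min, max, mean, and std over finite values only, which is an assumption about how NaN/Inf entries should be treated.

```python
import math
import statistics


def feature_stats(features):
    """Per-channel statistics for an [N, C] list of rows, ignoring NaN/Inf in the summary."""
    report = []
    num_channels = len(features[0]) if features else 0
    for channel in range(num_channels):
        column = [row[channel] for row in features]
        finite = [value for value in column if math.isfinite(value)]
        report.append({
            "channel": channel,
            "nan": sum(math.isnan(value) for value in column),
            "inf": sum(math.isinf(value) for value in column),
            "min": min(finite) if finite else None,
            "max": max(finite) if finite else None,
            "mean": statistics.fmean(finite) if finite else None,
            "std": statistics.pstdev(finite) if finite else None,
        })
    return report


toy = [[0.5, 1.0], [1.5, float("nan")], [-1.0, float("inf")]]
report = feature_stats(toy)
for entry in report:
    print(entry)
```

A non-zero `nan` or `inf` count is usually the first thing to look for when a generated shape comes out broken.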
## Quick Start with example_visualization.py
The repo includes [example_visualization.py](example_visualization.py), a standalone script that runs the full pipeline, saves stage visualizations, renders multiple views, and exports a raw `.obj` file — all in one shot. It is the fastest way to verify the pipeline is working correctly on your hardware.
### What it does
1. Runs the pipeline on a test image with all visualization stages enabled
2. Exports a raw `.obj` file (no renderer, no nvdiffrast) — load this in Blender to verify geometry completeness independently of the renderer
3. Renders N views using the same `render_snapshot` path as `app.py` and saves contact sheets
### Configuration
Edit the constants at the top of the file:
```python
IMAGE_PATH = "assets/example_image/T2.png" # input image
PIPELINE = "1024_cascade" # '512' | '1024' | '1024_cascade' | '1536_cascade'
SEED = 42
NVIEWS = 8 # render views
RENDER_RES = 1024 # render resolution
VIZ_DIR = "visualizations_render_test" # output directory
```
### Running
```sh
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" # AMD only
python example_visualization.py
```
Output files in `VIZ_DIR/`:
- `raw_mesh.obj` — raw geometry, no renderer involved. If this looks correct in Blender but renders look wrong, the bug is in the rasterizer path, not the pipeline.
- `render_frames/contact_shaded_all_views.png` — all rendered views side by side
- `render_frames/contact_normal_all_views.png` — surface normals
- `render_frames/contact_base_color_all_views.png` — albedo without lighting
- Per-stage visualization PNGs (sparse structure, HR coords, SLat features, etc.)
### Diagnosing issues
The `.obj` export is intentionally renderer-free. If the `.obj` geometry is complete but render images show only 15–30% coverage, the issue is in the nvdiffrast/rasterizer path. If the `.obj` itself looks wrong, the issue is earlier in the pipeline.
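
If Blender is not at hand, a quick sanity check is to count vertices and faces in the exported file. The sketch below does this for the plain-text OBJ format; `count_obj_elements` is a hypothetical helper, and the commented path assumes the default `VIZ_DIR` shown earlier.

```python
def count_obj_elements(lines):
    """Count 'v' (vertex) and 'f' (face) records in an iterable of OBJ lines."""
    vertices = faces = 0
    for line in lines:
        parts = line.split()
        if not parts:
            continue
        if parts[0] == "v":
            vertices += 1
        elif parts[0] == "f":
            faces += 1
    return vertices, faces


sample_obj = """\
# toy triangle
v 0.0 0.0 0.0
v 1.0 0.0 0.0
v 0.0 1.0 0.0
vn 0.0 0.0 1.0
f 1 2 3
"""
print(count_obj_elements(sample_obj.splitlines()))  # (3, 1)

# For a real export, pass the open file instead, e.g.:
# with open("visualizations_render_test/raw_mesh.obj") as f:
#     print(count_obj_elements(f))
```

A face count of zero, or one far below what you expect for the chosen pipeline type, points at a problem before the renderer is involved.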
## Usage
### Basic Usage in Your Code
```python
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from PIL import Image
# Load pipeline
pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
pipeline.cuda()
# Load image
image = Image.open("path/to/image.png")
# Run with visualization
mesh = pipeline.run(
    image,
    seed=42,
    pipeline_type='1024_cascade',
    visualize_sparse_structure=True,  # Enable visualization
    visualize_save_dir=None,  # None = interactive display
)
```
### Saving Visualizations to Disk
```python
mesh = pipeline.run(
    image,
    seed=42,
    pipeline_type='1024_cascade',
    visualize_sparse_structure=True,
    visualize_save_dir='my_visualizations',  # Save to directory
)
```
This will create multiple files for each visualization stage.
### Statistical Analysis Only
```python
# Generate sparse structure
coords = pipeline.sample_sparse_structure(
    pipeline.get_cond([image], 512),
    resolution=32,
    num_samples=1,
    sampler_params={}
)
# Analyze without visualization
pipeline.analyze_sparse_structure(coords)
```
Output:
```
Sparse Structure Analysis:
Total occupied voxels: 15234
X range: [2, 29]
Y range: [5, 26]
Z range: [3, 28]
Center: [15.2, 15.8, 14.9]
Std dev: [6.3, 5.9, 7.1]
Bounding box volume: 5832
```
## Parameters
### `visualize_sparse_structure` (bool)
- **Default:** `False`
- **Description:** Enable or disable sparse structure visualization
- **Usage:** Set to `True` to visualize the sparse structure after generation
### `visualize_save_dir` (str or None)
- **Default:** `None`
- **Description:** Directory path to save visualization images
- **Usage:**
- `None`: Display visualizations interactively (blocks execution)
- `"/path/to/dir"`: Save visualizations to disk (non-blocking)
## Understanding the Visualizations
### Resolution Differences
Different pipeline types use different sparse structure resolutions:
| Pipeline Type | Sparse Structure Resolution | Grid Size | Typical Voxel Count |
|--------------|----------------------------|-----------|---------------------|
| 512 | 32 | 32³ = 32,768 | ~5,000 - 15,000 |
| 1024 | 64 | 64³ = 262,144 | ~20,000 - 50,000 |
| 1024_cascade | 32 | 32³ = 32,768 | ~5,000 - 15,000 |
| 1536_cascade | 32 | 32³ = 32,768 | ~5,000 - 15,000 |
**Note:** Cascade modes use the same sparse structure resolution as 512, but later upsample during shape generation.
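
The grid sizes in the table are just the cube of the sparse structure resolution, which also lets you turn a voxel count into an occupancy fraction. A minimal sketch, with the mapping copied from the table and hypothetical helper names:

```python
# Sparse structure resolution per pipeline type, copied from the table above.
SPARSE_RESOLUTION = {"512": 32, "1024": 64, "1024_cascade": 32, "1536_cascade": 32}


def grid_size(pipeline_type: str) -> int:
    """Number of cells in the dense grid for the given pipeline type."""
    return SPARSE_RESOLUTION[pipeline_type] ** 3


def occupancy(pipeline_type: str, voxel_count: int) -> float:
    """Fraction of the dense grid that is occupied."""
    return voxel_count / grid_size(pipeline_type)


for name in SPARSE_RESOLUTION:
    print(f"{name}: {grid_size(name):,} cells")
print(f"15,000 voxels in a 32^3 grid: {occupancy('512', 15_000):.1%} occupied")
```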
### Color Coding
- **Z-coordinate coloring**: Points are colored by their Z position (using viridis colormap)
- **Higher Z values**: Yellow/green (top of object)
- **Lower Z values**: Purple/blue (bottom of object)
## Examples
### Example 1: Compare Different Pipeline Types
```python
import os

for pipeline_type in ['512', '1024_cascade', '1536_cascade']:
    print(f"Generating with {pipeline_type}...")
    save_dir = os.path.join('comparison', pipeline_type)
    os.makedirs(save_dir, exist_ok=True)  # make sure the output directory exists
    mesh = pipeline.run(
        image,
        seed=42,
        pipeline_type=pipeline_type,
        visualize_sparse_structure=True,
        visualize_save_dir=save_dir,
    )
```
### Example 2: Debug Generation Issues
```python
# Generate with visualization to check sparse structure
mesh = pipeline.run(
image,
seed=42,
pipeline_type='1024_cascade',
visualize_sparse_structure=True,
visualize_save_dir='debug_output',
)
# If sparse structure looks abnormal, you can:
# 1. Check if voxel count is too high/low
# 2. Verify coordinate ranges are within expected bounds
# 3. Compare with known good examples
```
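The first two checks listed in the comments can be automated. The following is a sketch under stated assumptions: `check_sparse_structure` and `TYPICAL_COUNTS` are hypothetical names, and the count thresholds are simply the "typical" ranges from the resolution table above, not hard limits enforced by the pipeline.

```python
import numpy as np

# Typical voxel-count ranges per sparse structure resolution, from the table above
TYPICAL_COUNTS = {32: (5_000, 15_000), 64: (20_000, 50_000)}

def check_sparse_structure(coords: np.ndarray, resolution: int) -> list:
    """Return human-readable warnings; an empty list means nothing looked off."""
    warnings = []
    xyz = coords[:, 1:4]
    if xyz.shape[0] == 0:
        return ["no occupied voxels"]
    if xyz.min() < 0 or xyz.max() > resolution - 1:
        warnings.append("coordinates outside [0, resolution - 1]")
    lo, hi = TYPICAL_COUNTS.get(resolution, (0, resolution ** 3))
    if not lo <= xyz.shape[0] <= hi:
        warnings.append(
            f"voxel count {xyz.shape[0]} outside typical range {lo}-{hi}"
        )
    return warnings

# Synthetic stand-in for real coordinates: 100 random voxels at resolution 32
rng = np.random.default_rng(0)
fake = np.concatenate(
    [np.zeros((100, 1), dtype=np.int64), rng.integers(0, 32, size=(100, 3))],
    axis=1,
)
print(check_sparse_structure(fake, 32))
```

A count outside the typical range is not necessarily an error (thin or very bulky objects can fall outside it), so the result is best treated as a prompt to look at the saved visualizations.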
### Example 3: Batch Analysis
```python
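import pandas as pd
import torch

results = []
for seed in range(10):
    # Seed PyTorch's global RNG so each iteration is reproducible and differs by seed
    # (assumes the sampler draws its noise from the global RNG, as pipeline.run(seed=...) does)
    torch.manual_seed(seed)
    coords = pipeline.sample_sparse_structure(
        pipeline.get_cond([image], 512),
        resolution=32,
        num_samples=1,
        sampler_params={}
    )
    coords_np = coords.cpu().numpy()
    results.append({
        'seed': seed,
        'num_voxels': len(coords),
        'x_range': coords_np[:, 1].max() - coords_np[:, 1].min(),
        # ... more fields
    })
df = pd.DataFrame(results)
print(df.describe())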
```
### Example 4: Complete Cascade Visualization
```python
# Visualize complete cascade process with all stages
mesh = pipeline.run(
image,
seed=42,
pipeline_type='1024_cascade',
visualize_sparse_structure=True,
visualize_save_dir='complete_cascade',
)
# This creates visualizations for:
# 1. Initial sparse structure
# 2. HR coordinates (upsampled)
# 3. Quantized coordinates
# 4. Final SLat features
# 5. Texture features
```
## Troubleshooting
### Issue: Plots don't display
**Solution:** Make sure you're running in an environment with display support (not headless). For headless environments, use `visualize_save_dir` to save files instead.
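One way to avoid this issue is to decide up front whether to save or display. The sketch below is a heuristic, not something the pipeline does: `pick_save_dir` and the `sparse_vis` fallback directory are hypothetical, and checking `DISPLAY` / `WAYLAND_DISPLAY` only makes sense on Linux desktops.

```python
import os

def pick_save_dir(requested, fallback="sparse_vis", environ=None):
    """Return a directory to save plots to, or None for interactive display.

    If the caller asked for a directory, use it. Otherwise fall back to saving
    on disk whenever no X11/Wayland display appears to be available.
    """
    environ = os.environ if environ is None else environ
    if requested is not None:
        return requested
    has_display = bool(environ.get("DISPLAY") or environ.get("WAYLAND_DISPLAY"))
    return None if has_display else fallback

print(pick_save_dir(None, environ={}))                 # headless: save to disk
print(pick_save_dir(None, environ={"DISPLAY": ":0"}))  # display present: interactive
```

The returned value can then be passed as `visualize_save_dir` to `pipeline.run`.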
### Issue: Too many voxels, visualization is slow
**Solution:** The visualization can be slow for very large sparse structures (>50,000 voxels). Consider:
1. Using lower resolution pipeline types
2. Saving to disk instead of interactive display
3. Using statistical analysis instead of full visualization
### Issue: Out of memory during visualization
**Solution:** Matplotlib can use significant memory for large plots. Try:
1. Saving to disk instead of interactive display
2. Using only the 2D projections method
3. Using statistical analysis only
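Another option for both slow and memory-hungry plots is to subsample the coordinates before plotting. This is a sketch of that idea, not a pipeline feature: `subsample_coords` and its `max_points` default are hypothetical, and a random subset only approximates the full shape.

```python
import numpy as np

def subsample_coords(coords: np.ndarray, max_points: int = 20_000, seed: int = 0) -> np.ndarray:
    """Return at most max_points rows of coords, chosen uniformly without replacement."""
    if coords.shape[0] <= max_points:
        return coords
    rng = np.random.default_rng(seed)
    keep = rng.choice(coords.shape[0], size=max_points, replace=False)
    # Sort the kept indices so the original row order is preserved
    return coords[np.sort(keep)]

big = np.zeros((60_000, 4), dtype=np.int64)
small = subsample_coords(big)
print(small.shape)
```

The reduced array can be converted back to a tensor and passed to one of the visualization methods, keeping the full array for statistical analysis.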
## Advanced Usage
### Custom Visualization
You can also call individual visualization methods directly:
```python
# Get sparse structure
coords = pipeline.sample_sparse_structure(
pipeline.get_cond([image], 512),
resolution=32,
num_samples=1,
sampler_params={}
)
# Use specific visualization method
pipeline.visualize_sparse_structure_projections(
coords,
resolution=32,
title="My Custom Title",
save_path="custom_output.png"
)
```
### Integration with Existing Code
```python
# In your existing pipeline code
def my_generation_function(image, seed):
# Generate sparse structure
coords = pipeline.sample_sparse_structure(
pipeline.get_cond([image], 512),
resolution=32,
num_samples=1,
sampler_params={}
)
# Analyze
pipeline.analyze_sparse_structure(coords)
# Continue with generation
shape_slat = pipeline.sample_shape_slat(
pipeline.get_cond([image], 512),
pipeline.models['shape_slat_flow_model_512'],
coords,
{}
)
# ... rest of your code
```
## References
- Main pipeline code: `trellis2/pipelines/trellis2_image_to_3d.py`
- Example script: `example_visualization.py`
- Sparse structure sampling: `sample_sparse_structure()` method (lines 189-236)
- Visualization methods: Lines 472-690 in `trellis2_image_to_3d.py`
import gradio as gr
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HSA_XNACK"] = "1"
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "garbage_collection_threshold:0.6,max_split_size_mb:128"
from datetime import datetime
import shutil
import cv2
from typing import *
import torch
# Cap PyTorch to 90% of VRAM. On ROCm, exceeding 100% faults the GPU driver
# and hangs the display rather than raising a Python OOM exception.
# 90% leaves headroom for the display driver and system allocations.
torch.cuda.set_per_process_memory_fraction(0.90)
import numpy as np
from PIL import Image
import base64
import io
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
from trellis2.utils.pipeline_logger import (
reset_log, get_logger, section, log_mesh, log_tensor, log_uv, elapsed, set_debug
)
import o_voxel
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
MODES = [
{"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
{"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
{"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
{"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
{"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
{"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8
DEFAULT_MODE = 3
DEFAULT_STEP = 3
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
padding: 0;
}
.stepper-container {
padding: 0;
align-items: center;
}
.step-button {
flex-direction: row;
}
.step-connector {
transform: none;
}
.step-number {
width: 16px;
height: 16px;
}
.step-label {
position: relative;
bottom: 0;
}
.wrap.center.full {
inset: 0;
height: 100%;
}
.wrap.center.full.translucent {
background: var(--block-background-fill);
}
.meta-text-center {
display: block !important;
position: absolute !important;
top: unset !important;
bottom: 0 !important;
right: 0 !important;
transform: unset !important;
}
/* Previewer */
.previewer-container {
position: relative;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
width: 100%;
height: 722px;
margin: 0 auto;
padding: 20px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.previewer-container .tips-icon {
position: absolute;
right: 10px;
top: 10px;
z-index: 10;
border-radius: 10px;
color: #fff;
background-color: var(--color-accent);
padding: 3px 6px;
user-select: none;
}
.previewer-container .tips-text {
position: absolute;
right: 10px;
top: 50px;
color: #fff;
background-color: var(--color-accent);
border-radius: 10px;
padding: 6px;
text-align: left;
max-width: 300px;
z-index: 10;
transition: all 0.3s;
opacity: 0%;
user-select: none;
}
.previewer-container .tips-text p {
font-size: 14px;
line-height: 1.2;
}
.tips-icon:hover + .tips-text {
display: block;
opacity: 100%;
}
/* Row 1: Display Modes */
.previewer-container .mode-row {
width: 100%;
display: flex;
gap: 8px;
justify-content: center;
margin-bottom: 20px;
flex-wrap: wrap;
}
.previewer-container .mode-btn {
width: 24px;
height: 24px;
border-radius: 50%;
cursor: pointer;
opacity: 0.5;
transition: all 0.2s;
border: 2px solid #ddd;
object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active {
opacity: 1;
border-color: var(--color-accent);
transform: scale(1.1);
}
/* Row 2: Display Image */
.previewer-container .display-row {
margin-bottom: 20px;
min-height: 400px;
width: 100%;
flex-grow: 1;
display: flex;
justify-content: center;
align-items: center;
}
.previewer-container .previewer-main-image {
max-width: 100%;
max-height: 100%;
flex-grow: 1;
object-fit: contain;
display: none;
}
.previewer-container .previewer-main-image.visible {
display: block;
}
/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
padding: 0 10px;
}
.previewer-container input[type=range] {
-webkit-appearance: none;
width: 100%;
max-width: 400px;
background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
width: 100%;
height: 8px;
cursor: pointer;
background: #ddd;
border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
height: 20px;
width: 20px;
border-radius: 50%;
background: var(--color-accent);
cursor: pointer;
-webkit-appearance: none;
margin-top: -6px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
transform: scale(1.2);
}
/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
padding: 0 !important;
}
.gradio-container:has(.previewer-container) [data-testid="block-label"] {
position: absolute;
top: 0;
left: 0;
}
"""
head = """
<script>
function refreshView(mode, step) {
// 1. Find current mode and step
const allImgs = document.querySelectorAll('.previewer-main-image');
for (let i = 0; i < allImgs.length; i++) {
const img = allImgs[i];
if (img.classList.contains('visible')) {
const id = img.id;
const [_, m, s] = id.split('-');
if (mode === -1) mode = parseInt(m.slice(1));
if (step === -1) step = parseInt(s.slice(1));
break;
}
}
// 2. Hide ALL images
// We select all elements with class 'previewer-main-image'
allImgs.forEach(img => img.classList.remove('visible'));
// 3. Construct the specific ID for the current state
// Format: view-m{mode}-s{step}
const targetId = 'view-m' + mode + '-s' + step;
const targetImg = document.getElementById(targetId);
// 4. Show ONLY the target
if (targetImg) {
targetImg.classList.add('visible');
}
// 5. Update Button Highlights
const allBtns = document.querySelectorAll('.mode-btn');
allBtns.forEach((btn, idx) => {
if (idx === mode) btn.classList.add('active');
else btn.classList.remove('active');
});
}
// --- Action: Switch Mode ---
function selectMode(mode) {
refreshView(mode, -1);
}
// --- Action: Slider Change ---
function onSliderChange(val) {
refreshView(-1, parseInt(val));
}
</script>
"""
empty_html = f"""
<div class="previewer-container">
<svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
def image_to_base64(image):
buffered = io.BytesIO()
image = image.convert("RGB")
image.save(buffered, format="jpeg", quality=85)
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
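def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # The directory may already be gone (e.g. never created or cleaned up), so do not raise
    shutil.rmtree(user_dir, ignore_errors=True)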
def preprocess_image(image: Image.Image) -> Image.Image:
"""
Preprocess the input image.
Args:
image (Image.Image): The input image.
Returns:
Image.Image: The preprocessed image.
"""
processed_image = pipeline.preprocess_image(image)
return processed_image
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
shape_slat, tex_slat, res = latents
return {
'shape_slat_feats': shape_slat.feats.cpu().numpy(),
'tex_slat_feats': tex_slat.feats.cpu().numpy(),
'coords': shape_slat.coords.cpu().numpy(),
'res': res,
}
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
shape_slat = SparseTensor(
feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
coords=torch.from_numpy(state['coords']).cuda(),
)
tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
return shape_slat, tex_slat, state['res']
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
def image_to_3d(
image: Image.Image,
seed: int,
resolution: str,
ss_guidance_strength: float,
ss_guidance_rescale: float,
ss_sampling_steps: int,
ss_rescale_t: float,
shape_slat_guidance_strength: float,
shape_slat_guidance_rescale: float,
shape_slat_sampling_steps: int,
shape_slat_rescale_t: float,
tex_slat_guidance_strength: float,
tex_slat_guidance_rescale: float,
tex_slat_sampling_steps: int,
tex_slat_rescale_t: float,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
reset_log(f"resolution={resolution} seed={seed}")
L = get_logger()
L.info(f"image size={image.size} mode={image.mode}")
# --- Sampling ---
section("pipeline.run()")
outputs, latents = pipeline.run(
image,
seed=seed,
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale,
"rescale_t": ss_rescale_t,
},
shape_slat_sampler_params={
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
},
tex_slat_sampler_params={
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
},
pipeline_type={
"512": "512",
"1024": "1024_cascade",
"1536": "1536_cascade",
}[resolution],
return_latent=True,
)
section("Post-run mesh inspection")
mesh = outputs[0]
L.info(f" mesh type: {type(mesh).__name__}")
log_mesh(mesh.vertices, mesh.faces, "pre-simplify")
if hasattr(mesh, 'coords'):
log_tensor(mesh.coords, "mesh.coords")
if hasattr(mesh, 'attrs'):
log_tensor(mesh.attrs, "mesh.attrs")
section("mesh.simplify(16777216)")
#mesh.simplify(16777216) # nvdiffrast limit
log_mesh(mesh.vertices, mesh.faces, "post-simplify")
section("render_snapshot")
L.info(f" resolution=1024 nviews={STEPS}")
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
section("render_snapshot complete")
for key, frames in images.items():
arr = frames[0]
L.info(f" render[{key}][0]: shape={arr.shape} "
f"min={arr.min():.3f} max={arr.max():.3f} "
f"NaN={bool(np.isnan(arr).any())}")
# Uncomment the 3 lines below to skip the nvdiffrast preview render and use placeholder images instead.
#print("Skipping nvdiffrast preview render for ROCm compatibility...")
#dummy_img = np.array(image.resize((512, 512)).convert("RGB"))
#images = {mode['render_key']: [dummy_img] * STEPS for mode in MODES}
section("extract_glb (GLB export path)")
L.info(f" GLB re-decode will use resolution={latents[2]}")
state = pack_state(latents)
torch.cuda.empty_cache()
L.info(f" {elapsed()} image_to_3d complete")
# --- HTML Construction ---
# The Stack of 48 Images
images_html = ""
for m_idx, mode in enumerate(MODES):
for s_idx in range(STEPS):
# ID Naming Convention: view-m{mode}-s{step}
unique_id = f"view-m{m_idx}-s{s_idx}"
# Logic: Only the default view (DEFAULT_MODE, DEFAULT_STEP) is visible initially
is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
vis_class = "visible" if is_visible else ""
# Image Source
img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
# Render the Tag
images_html += f"""
<img id="{unique_id}"
class="previewer-main-image {vis_class}"
src="{img_base64}"
loading="eager">
"""
# Button Row HTML
btns_html = ""
for idx, mode in enumerate(MODES):
active_class = "active" if idx == DEFAULT_MODE else ""
# Note: onclick calls the JS function defined in Head
btns_html += f"""
<img src="{mode['icon_base64']}"
class="mode-btn {active_class}"
onclick="selectMode({idx})"
title="{mode['name']}">
"""
# Assemble the full component
full_html = f"""
<div class="previewer-container">
<div class="tips-wrapper">
<div class="tips-icon">💡Tips</div>
<div class="tips-text">
<p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
<p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
</div>
</div>
<!-- Row 1: Viewport containing 48 static <img> tags -->
<div class="display-row">
{images_html}
</div>
<!-- Row 2 -->
<div class="mode-row" id="btn-group">
{btns_html}
</div>
<!-- Row 3: Slider -->
<div class="slider-row">
<input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
</div>
</div>
"""
return state, full_html
def extract_glb(
state: dict,
decimation_target: int,
texture_size: int,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
"""
Extract a GLB file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
decimation_target (int): The target face count for decimation.
texture_size (int): The texture resolution.
Returns:
Tuple[str, str]: The path to the extracted GLB file, returned twice (for the 3D viewer and for the download button).
"""
L = get_logger()
section(f"extract_glb decimation={decimation_target} tex={texture_size}")
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
shape_slat, tex_slat, res = unpack_state(state)
L.info(f" res={res}")
section("decode_latent (GLB path)")
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
log_mesh(mesh.vertices, mesh.faces, "glb-decoded")
if hasattr(mesh, 'coords'):
log_tensor(mesh.coords, "glb.coords")
if hasattr(mesh, 'attrs'):
log_tensor(mesh.attrs, "glb.attrs")
section("o_voxel.postprocess.to_glb")
L.info(f" grid_size={res} decimation={decimation_target} texture_size={texture_size}")
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices,
faces=mesh.faces,
attr_volume=mesh.attrs,
coords=mesh.coords,
attr_layout=pipeline.pbr_attr_layout,
grid_size=res,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target,
texture_size=texture_size,
remesh=True,
remesh_band=1,
remesh_project=0,
use_tqdm=True,
)
now = datetime.now()
timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
os.makedirs(user_dir, exist_ok=True)
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
glb.export(glb_path, extension_webp=True)
torch.cuda.empty_cache()
return glb_path, glb_path
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
* Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
* Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
""")
with gr.Row():
with gr.Column(scale=1, min_width=360):
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000)
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
generate_btn = gr.Button("Generate")
with gr.Accordion(label="Advanced Settings", open=False):
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
gr.Markdown("Stage 2: Shape Generation")
with gr.Row():
shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
gr.Markdown("Stage 3: Material Generation")
with gr.Row():
tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
with gr.Column(scale=10):
with gr.Walkthrough(selected=0) as walkthrough:
with gr.Step("Preview", id=0):
preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
extract_btn = gr.Button("Extract GLB")
with gr.Step("Extract", id=1):
glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
download_btn = gr.DownloadButton(label="Download GLB")
with gr.Column(scale=1, min_width=172):
examples = gr.Examples(
examples=[
f'assets/example_image/{image}'
for image in os.listdir("assets/example_image")
],
inputs=[image_prompt],
fn=preprocess_image,
outputs=[image_prompt],
run_on_click=True,
examples_per_page=18,
)
output_buf = gr.State()
# Handlers
demo.load(start_session)
demo.unload(end_session)
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
lambda: gr.Walkthrough(selected=0), outputs=walkthrough
).then(
image_to_3d,
inputs=[
image_prompt, seed, resolution,
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
],
outputs=[output_buf, preview_output],
)
extract_btn.click(
lambda: gr.Walkthrough(selected=1), outputs=walkthrough
).then(
extract_glb,
inputs=[output_buf, decimation_target, texture_size],
outputs=[glb_output, download_btn],
)
# Launch the Gradio app
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='Enable debug-level pipeline logging')
args = parser.parse_args()
set_debug(args.debug)
os.makedirs(TMP_DIR, exist_ok=True)
# Construct ui components
btn_img_base64_strs = {}
for i in range(len(MODES)):
icon = Image.open(MODES[i]['icon'])
MODES[i]['icon_base64'] = image_to_base64(icon)
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
pipeline.cuda()
envmap = {
'forest': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'sunset': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'courtyard': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
}
demo.launch(css=css, head=head)
import gradio as gr
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from datetime import datetime
import shutil
from typing import *
import torch
import numpy as np
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
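def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # The directory may already be gone (e.g. never created or cleaned up), so do not raise
    shutil.rmtree(user_dir, ignore_errors=True)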
def preprocess_image(image: Image.Image) -> Image.Image:
"""
Preprocess the input image.
Args:
image (Image.Image): The input image.
Returns:
Image.Image: The preprocessed image.
"""
processed_image = pipeline.preprocess_image(image)
return processed_image
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
def shapeimage_to_tex(
mesh_file: str,
image: Image.Image,
seed: int,
resolution: str,
texture_size: int,
tex_slat_guidance_strength: float,
tex_slat_guidance_rescale: float,
tex_slat_sampling_steps: int,
tex_slat_rescale_t: float,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
mesh = trimesh.load(mesh_file)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.to_mesh()
output = pipeline.run(
mesh,
image,
seed=seed,
preprocess_image=False,
tex_slat_sampler_params={
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
},
resolution=int(resolution),
texture_size=texture_size,
)
now = datetime.now()
timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
output.export(glb_path, extension_webp=True)
torch.cuda.empty_cache()
return glb_path, glb_path
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
* Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
""")
with gr.Row():
with gr.Column(scale=1, min_width=360):
mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single")
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
generate_btn = gr.Button("Generate")
with gr.Accordion(label="Advanced Settings", open=False):
with gr.Row():
tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
with gr.Column(scale=10):
glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
download_btn = gr.DownloadButton(label="Download GLB")
# Handlers
demo.load(start_session)
demo.unload(end_session)
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
shapeimage_to_tex,
inputs=[
mesh_file, image_prompt, seed, resolution, texture_size,
tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
],
outputs=[glb_output, download_btn],
)
# Launch the Gradio app
if __name__ == "__main__":
os.makedirs(TMP_DIR, exist_ok=True)
pipeline = Trellis2TexturingPipeline.from_pretrained('microsoft/TRELLIS.2-4B', config_file="texturing_pipeline.json")
pipeline.cuda()
demo.launch()