Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3333 additions
and
0 deletions
+3333
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/__init__.cpython-310.pyc
...2/pipelines/samplers/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/base.cpython-310.pyc
...llis2/pipelines/samplers/__pycache__/base.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc
..._pycache__/classifier_free_guidance_mixin.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc
...pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc
...plers/__pycache__/guidance_interval_mixin.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/base.py
TRELLIS.2_DCU/trellis2/pipelines/samplers/base.py
+20
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py
...lis2/pipelines/samplers/classifier_free_guidance_mixin.py
+29
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/flow_euler.py
TRELLIS.2_DCU/trellis2/pipelines/samplers/flow_euler.py
+208
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/guidance_interval_mixin.py
...CU/trellis2/pipelines/samplers/guidance_interval_mixin.py
+13
-0
TRELLIS.2_DCU/trellis2/pipelines/trellis2_image_to_3d.py
TRELLIS.2_DCU/trellis2/pipelines/trellis2_image_to_3d.py
+1538
-0
TRELLIS.2_DCU/trellis2/pipelines/trellis2_texturing.py
TRELLIS.2_DCU/trellis2/pipelines/trellis2_texturing.py
+408
-0
TRELLIS.2_DCU/trellis2/renderers/__init__.py
TRELLIS.2_DCU/trellis2/renderers/__init__.py
+33
-0
TRELLIS.2_DCU/trellis2/renderers/__pycache__/__init__.cpython-310.pyc
...U/trellis2/renderers/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/renderers/__pycache__/mesh_renderer.cpython-310.pyc
...llis2/renderers/__pycache__/mesh_renderer.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/renderers/__pycache__/pbr_mesh_renderer.cpython-310.pyc
...2/renderers/__pycache__/pbr_mesh_renderer.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/renderers/__pycache__/voxel_renderer.cpython-310.pyc
...lis2/renderers/__pycache__/voxel_renderer.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/renderers/mesh_renderer.py
TRELLIS.2_DCU/trellis2/renderers/mesh_renderer.py
+414
-0
TRELLIS.2_DCU/trellis2/renderers/pbr_mesh_renderer.py
TRELLIS.2_DCU/trellis2/renderers/pbr_mesh_renderer.py
+571
-0
TRELLIS.2_DCU/trellis2/renderers/voxel_renderer.py
TRELLIS.2_DCU/trellis2/renderers/voxel_renderer.py
+68
-0
TRELLIS.2_DCU/trellis2/representations/__init__.py
TRELLIS.2_DCU/trellis2/representations/__init__.py
+31
-0
No files found.
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/base.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/base.py
0 → 100644
View file @
f05e915f
from
typing
import
*
from
abc
import
ABC
,
abstractmethod
class
Sampler
(
ABC
):
"""
A base class for samplers.
"""
@
abstractmethod
def
sample
(
self
,
model
,
**
kwargs
):
"""
Sample from a model.
"""
pass
\ No newline at end of file
TRELLIS.2_DCU/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py
0 → 100644
View file @
f05e915f
from
typing
import
*
class
ClassifierFreeGuidanceSamplerMixin
:
"""
A mixin class for samplers that apply classifier-free guidance.
"""
def
_inference_model
(
self
,
model
,
x_t
,
t
,
cond
,
neg_cond
,
guidance_strength
,
guidance_rescale
=
0.0
,
**
kwargs
):
if
guidance_strength
==
1
:
return
super
().
_inference_model
(
model
,
x_t
,
t
,
cond
,
**
kwargs
)
elif
guidance_strength
==
0
:
return
super
().
_inference_model
(
model
,
x_t
,
t
,
neg_cond
,
**
kwargs
)
else
:
pred_pos
=
super
().
_inference_model
(
model
,
x_t
,
t
,
cond
,
**
kwargs
)
pred_neg
=
super
().
_inference_model
(
model
,
x_t
,
t
,
neg_cond
,
**
kwargs
)
pred
=
guidance_strength
*
pred_pos
+
(
1
-
guidance_strength
)
*
pred_neg
# CFG rescale
if
guidance_rescale
>
0
:
x_0_pos
=
self
.
_pred_to_xstart
(
x_t
,
t
,
pred_pos
)
x_0_cfg
=
self
.
_pred_to_xstart
(
x_t
,
t
,
pred
)
std_pos
=
x_0_pos
.
std
(
dim
=
list
(
range
(
1
,
x_0_pos
.
ndim
)),
keepdim
=
True
)
std_cfg
=
x_0_cfg
.
std
(
dim
=
list
(
range
(
1
,
x_0_cfg
.
ndim
)),
keepdim
=
True
)
x_0_rescaled
=
x_0_cfg
*
(
std_pos
/
std_cfg
)
x_0
=
guidance_rescale
*
x_0_rescaled
+
(
1
-
guidance_rescale
)
*
x_0_cfg
pred
=
self
.
_xstart_to_pred
(
x_t
,
t
,
x_0
)
return
pred
TRELLIS.2_DCU/trellis2/pipelines/samplers/flow_euler.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
numpy
as
np
from
tqdm
import
tqdm
from
easydict
import
EasyDict
as
edict
from
.base
import
Sampler
from
.classifier_free_guidance_mixin
import
ClassifierFreeGuidanceSamplerMixin
from
.guidance_interval_mixin
import
GuidanceIntervalSamplerMixin
class
FlowEulerSampler
(
Sampler
):
"""
Generate samples from a flow-matching model using Euler sampling.
Args:
sigma_min: The minimum scale of noise in flow.
"""
def
__init__
(
self
,
sigma_min
:
float
,
):
self
.
sigma_min
=
sigma_min
def
_eps_to_xstart
(
self
,
x_t
,
t
,
eps
):
assert
x_t
.
shape
==
eps
.
shape
return
(
x_t
-
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
*
eps
)
/
(
1
-
t
)
def
_xstart_to_eps
(
self
,
x_t
,
t
,
x_0
):
assert
x_t
.
shape
==
x_0
.
shape
return
(
x_t
-
(
1
-
t
)
*
x_0
)
/
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
def
_v_to_xstart_eps
(
self
,
x_t
,
t
,
v
):
assert
x_t
.
shape
==
v
.
shape
eps
=
(
1
-
t
)
*
v
+
x_t
x_0
=
(
1
-
self
.
sigma_min
)
*
x_t
-
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
*
v
return
x_0
,
eps
def
_pred_to_xstart
(
self
,
x_t
,
t
,
pred
):
return
(
1
-
self
.
sigma_min
)
*
x_t
-
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
*
pred
def
_xstart_to_pred
(
self
,
x_t
,
t
,
x_0
):
return
((
1
-
self
.
sigma_min
)
*
x_t
-
x_0
)
/
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
def
_inference_model
(
self
,
model
,
x_t
,
t
,
cond
=
None
,
**
kwargs
):
t
=
torch
.
tensor
([
1000
*
t
]
*
x_t
.
shape
[
0
],
device
=
x_t
.
device
,
dtype
=
torch
.
float32
)
return
model
(
x_t
,
t
,
cond
,
**
kwargs
)
def
_get_model_prediction
(
self
,
model
,
x_t
,
t
,
cond
=
None
,
**
kwargs
):
pred_v
=
self
.
_inference_model
(
model
,
x_t
,
t
,
cond
,
**
kwargs
)
pred_x_0
,
pred_eps
=
self
.
_v_to_xstart_eps
(
x_t
=
x_t
,
t
=
t
,
v
=
pred_v
)
return
pred_x_0
,
pred_eps
,
pred_v
@
torch
.
no_grad
()
def
sample_once
(
self
,
model
,
x_t
,
t
:
float
,
t_prev
:
float
,
cond
:
Optional
[
Any
]
=
None
,
**
kwargs
):
"""
Sample x_{t-1} from the model using Euler method.
Args:
model: The model to sample from.
x_t: The [N x C x ...] tensor of noisy inputs at time t.
t: The current timestep.
t_prev: The previous timestep.
cond: conditional information.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'pred_x_prev': x_{t-1}.
- 'pred_x_0': a prediction of x_0.
"""
pred_x_0
,
pred_eps
,
pred_v
=
self
.
_get_model_prediction
(
model
,
x_t
,
t
,
cond
,
**
kwargs
)
pred_x_prev
=
x_t
-
(
t
-
t_prev
)
*
pred_v
return
edict
({
"pred_x_prev"
:
pred_x_prev
,
"pred_x_0"
:
pred_x_0
})
@
torch
.
no_grad
()
def
sample
(
self
,
model
,
noise
,
cond
:
Optional
[
Any
]
=
None
,
steps
:
int
=
50
,
rescale_t
:
float
=
1.0
,
verbose
:
bool
=
True
,
tqdm_desc
:
str
=
"Sampling"
,
**
kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
verbose: If True, show a progress bar.
tqdm_desc: A customized tqdm desc.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
sample
=
noise
t_seq
=
np
.
linspace
(
1
,
0
,
steps
+
1
)
t_seq
=
rescale_t
*
t_seq
/
(
1
+
(
rescale_t
-
1
)
*
t_seq
)
t_seq
=
t_seq
.
tolist
()
t_pairs
=
list
((
t_seq
[
i
],
t_seq
[
i
+
1
])
for
i
in
range
(
steps
))
ret
=
edict
({
"samples"
:
None
,
"pred_x_t"
:
[],
"pred_x_0"
:
[]})
for
t
,
t_prev
in
tqdm
(
t_pairs
,
desc
=
tqdm_desc
,
disable
=
not
verbose
):
out
=
self
.
sample_once
(
model
,
sample
,
t
,
t_prev
,
cond
,
**
kwargs
)
sample
=
out
.
pred_x_prev
ret
.
pred_x_t
.
append
(
out
.
pred_x_prev
)
ret
.
pred_x_0
.
append
(
out
.
pred_x_0
)
ret
.
samples
=
sample
return
ret
class
FlowEulerCfgSampler
(
ClassifierFreeGuidanceSamplerMixin
,
FlowEulerSampler
):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
"""
@
torch
.
no_grad
()
def
sample
(
self
,
model
,
noise
,
cond
,
neg_cond
,
steps
:
int
=
50
,
rescale_t
:
float
=
1.0
,
guidance_strength
:
float
=
3.0
,
verbose
:
bool
=
True
,
**
kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
guidance_strength: The strength of classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
return
super
().
sample
(
model
,
noise
,
cond
,
steps
,
rescale_t
,
verbose
,
neg_cond
=
neg_cond
,
guidance_strength
=
guidance_strength
,
**
kwargs
)
class
FlowEulerGuidanceIntervalSampler
(
GuidanceIntervalSamplerMixin
,
ClassifierFreeGuidanceSamplerMixin
,
FlowEulerSampler
):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
"""
@
torch
.
no_grad
()
def
sample
(
self
,
model
,
noise
,
cond
,
neg_cond
,
steps
:
int
=
50
,
rescale_t
:
float
=
1.0
,
guidance_strength
:
float
=
3.0
,
guidance_interval
:
Tuple
[
float
,
float
]
=
(
0.0
,
1.0
),
verbose
:
bool
=
True
,
**
kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
guidance_strength: The strength of classifier-free guidance.
guidance_interval: The interval for classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
return
super
().
sample
(
model
,
noise
,
cond
,
steps
,
rescale_t
,
verbose
,
neg_cond
=
neg_cond
,
guidance_strength
=
guidance_strength
,
guidance_interval
=
guidance_interval
,
**
kwargs
)
TRELLIS.2_DCU/trellis2/pipelines/samplers/guidance_interval_mixin.py
0 → 100644
View file @
f05e915f
from
typing
import
*
class
GuidanceIntervalSamplerMixin
:
"""
A mixin class for samplers that apply classifier-free guidance with interval.
"""
def
_inference_model
(
self
,
model
,
x_t
,
t
,
cond
,
guidance_strength
,
guidance_interval
,
**
kwargs
):
if
guidance_interval
[
0
]
<=
t
<=
guidance_interval
[
1
]:
return
super
().
_inference_model
(
model
,
x_t
,
t
,
cond
,
guidance_strength
=
guidance_strength
,
**
kwargs
)
else
:
return
super
().
_inference_model
(
model
,
x_t
,
t
,
cond
,
guidance_strength
=
1
,
**
kwargs
)
TRELLIS.2_DCU/trellis2/pipelines/trellis2_image_to_3d.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
PIL
import
Image
from
.base
import
Pipeline
from
.
import
samplers
,
rembg
from
..modules.sparse
import
SparseTensor
from
..modules
import
image_feature_extractor
from
..representations
import
Mesh
,
MeshWithVoxel
from
..utils.pipeline_logger
import
get_logger
,
log_sparse
,
log_mesh
,
log_tensor
,
section
,
elapsed
import
matplotlib.pyplot
as
plt
from
mpl_toolkits.mplot3d
import
Axes3D
class
Trellis2ImageTo3DPipeline
(
Pipeline
):
"""
Pipeline for inferring Trellis2 image-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
shape_slat_sampler (samplers.Sampler): The sampler for the structured latent.
tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler.
shape_slat_sampler_params (dict): The parameters for the structured latent sampler.
tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
shape_slat_normalization (dict): The normalization parameters for the structured latent.
tex_slat_normalization (dict): The normalization parameters for the texture latent.
image_cond_model (Callable): The image conditioning model.
rembg_model (Callable): The model for removing background.
low_vram (bool): Whether to use low-VRAM mode.
"""
model_names_to_load
=
[
'sparse_structure_flow_model'
,
'sparse_structure_decoder'
,
'shape_slat_flow_model_512'
,
'shape_slat_flow_model_1024'
,
'shape_slat_decoder'
,
'tex_slat_flow_model_512'
,
'tex_slat_flow_model_1024'
,
'tex_slat_decoder'
,
]
def
__init__
(
self
,
models
:
dict
[
str
,
nn
.
Module
]
=
None
,
sparse_structure_sampler
:
samplers
.
Sampler
=
None
,
shape_slat_sampler
:
samplers
.
Sampler
=
None
,
tex_slat_sampler
:
samplers
.
Sampler
=
None
,
sparse_structure_sampler_params
:
dict
=
None
,
shape_slat_sampler_params
:
dict
=
None
,
tex_slat_sampler_params
:
dict
=
None
,
shape_slat_normalization
:
dict
=
None
,
tex_slat_normalization
:
dict
=
None
,
image_cond_model
:
Callable
=
None
,
rembg_model
:
Callable
=
None
,
low_vram
:
bool
=
True
,
default_pipeline_type
:
str
=
'1024_cascade'
,
):
if
models
is
None
:
return
super
().
__init__
(
models
)
self
.
sparse_structure_sampler
=
sparse_structure_sampler
self
.
shape_slat_sampler
=
shape_slat_sampler
self
.
tex_slat_sampler
=
tex_slat_sampler
self
.
sparse_structure_sampler_params
=
sparse_structure_sampler_params
self
.
shape_slat_sampler_params
=
shape_slat_sampler_params
self
.
tex_slat_sampler_params
=
tex_slat_sampler_params
self
.
shape_slat_normalization
=
shape_slat_normalization
self
.
tex_slat_normalization
=
tex_slat_normalization
self
.
image_cond_model
=
image_cond_model
self
.
rembg_model
=
rembg_model
self
.
low_vram
=
low_vram
self
.
default_pipeline_type
=
default_pipeline_type
self
.
pbr_attr_layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
self
.
_device
=
'cpu'
@
classmethod
def
from_pretrained
(
cls
,
path
:
str
,
config_file
:
str
=
"pipeline.json"
)
->
"Trellis2ImageTo3DPipeline"
:
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline
=
super
().
from_pretrained
(
path
,
config_file
)
args
=
pipeline
.
_pretrained_args
pipeline
.
sparse_structure_sampler
=
getattr
(
samplers
,
args
[
'sparse_structure_sampler'
][
'name'
])(
**
args
[
'sparse_structure_sampler'
][
'args'
])
pipeline
.
sparse_structure_sampler_params
=
args
[
'sparse_structure_sampler'
][
'params'
]
pipeline
.
shape_slat_sampler
=
getattr
(
samplers
,
args
[
'shape_slat_sampler'
][
'name'
])(
**
args
[
'shape_slat_sampler'
][
'args'
])
pipeline
.
shape_slat_sampler_params
=
args
[
'shape_slat_sampler'
][
'params'
]
pipeline
.
tex_slat_sampler
=
getattr
(
samplers
,
args
[
'tex_slat_sampler'
][
'name'
])(
**
args
[
'tex_slat_sampler'
][
'args'
])
pipeline
.
tex_slat_sampler_params
=
args
[
'tex_slat_sampler'
][
'params'
]
pipeline
.
shape_slat_normalization
=
args
[
'shape_slat_normalization'
]
pipeline
.
tex_slat_normalization
=
args
[
'tex_slat_normalization'
]
pipeline
.
image_cond_model
=
getattr
(
image_feature_extractor
,
args
[
'image_cond_model'
][
'name'
])(
**
args
[
'image_cond_model'
][
'args'
])
pipeline
.
rembg_model
=
getattr
(
rembg
,
args
[
'rembg_model'
][
'name'
])(
**
args
[
'rembg_model'
][
'args'
])
pipeline
.
low_vram
=
args
.
get
(
'low_vram'
,
True
)
pipeline
.
default_pipeline_type
=
args
.
get
(
'default_pipeline_type'
,
'1024_cascade'
)
pipeline
.
pbr_attr_layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
pipeline
.
_device
=
'cpu'
return
pipeline
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
self
.
_device
=
device
if
not
self
.
low_vram
:
super
().
to
(
device
)
self
.
image_cond_model
.
to
(
device
)
if
self
.
rembg_model
is
not
None
:
self
.
rembg_model
.
to
(
device
)
def
preprocess_image
(
self
,
input
:
Image
.
Image
)
->
Image
.
Image
:
"""
Preprocess the input image.
"""
# if has alpha channel, use it directly; otherwise, remove background
has_alpha
=
False
if
input
.
mode
==
'RGBA'
:
alpha
=
np
.
array
(
input
)[:,
:,
3
]
if
not
np
.
all
(
alpha
==
255
):
has_alpha
=
True
max_size
=
max
(
input
.
size
)
scale
=
min
(
1
,
1024
/
max_size
)
if
scale
<
1
:
input
=
input
.
resize
((
int
(
input
.
width
*
scale
),
int
(
input
.
height
*
scale
)),
Image
.
Resampling
.
LANCZOS
)
if
has_alpha
:
output
=
input
else
:
input
=
input
.
convert
(
'RGB'
)
if
self
.
low_vram
:
self
.
rembg_model
.
to
(
self
.
device
)
output
=
self
.
rembg_model
(
input
)
if
self
.
low_vram
:
self
.
rembg_model
.
cpu
()
output_np
=
np
.
array
(
output
)
alpha
=
output_np
[:,
:,
3
]
bbox
=
np
.
argwhere
(
alpha
>
0.8
*
255
)
bbox
=
np
.
min
(
bbox
[:,
1
]),
np
.
min
(
bbox
[:,
0
]),
np
.
max
(
bbox
[:,
1
]),
np
.
max
(
bbox
[:,
0
])
center
=
(
bbox
[
0
]
+
bbox
[
2
])
/
2
,
(
bbox
[
1
]
+
bbox
[
3
])
/
2
size
=
max
(
bbox
[
2
]
-
bbox
[
0
],
bbox
[
3
]
-
bbox
[
1
])
size
=
int
(
size
*
1
)
bbox
=
center
[
0
]
-
size
//
2
,
center
[
1
]
-
size
//
2
,
center
[
0
]
+
size
//
2
,
center
[
1
]
+
size
//
2
output
=
output
.
crop
(
bbox
)
# type: ignore
output
=
np
.
array
(
output
).
astype
(
np
.
float32
)
/
255
output
=
output
[:,
:,
:
3
]
*
output
[:,
:,
3
:
4
]
output
=
Image
.
fromarray
((
output
*
255
).
astype
(
np
.
uint8
))
return
output
def
get_cond
(
self
,
image
:
Union
[
torch
.
Tensor
,
list
[
Image
.
Image
]],
resolution
:
int
,
include_neg_cond
:
bool
=
True
)
->
dict
:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
self
.
image_cond_model
.
image_size
=
resolution
if
self
.
low_vram
:
self
.
image_cond_model
.
to
(
self
.
device
)
cond
=
self
.
image_cond_model
(
image
)
if
self
.
low_vram
:
self
.
image_cond_model
.
cpu
()
if
not
include_neg_cond
:
return
{
'cond'
:
cond
}
neg_cond
=
torch
.
zeros_like
(
cond
)
return
{
'cond'
:
cond
,
'neg_cond'
:
neg_cond
,
}
def
sample_sparse_structure
(
self
,
cond
:
dict
,
resolution
:
int
,
num_samples
:
int
=
1
,
sampler_params
:
dict
=
{},
)
->
torch
.
Tensor
:
"""
Sample sparse structures with the given conditioning.
Args:
cond (dict): The conditioning information.
resolution (int): The resolution of the sparse structure.
num_samples (int): The number of samples to generate.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample sparse structure latent
flow_model
=
self
.
models
[
'sparse_structure_flow_model'
]
reso
=
flow_model
.
resolution
in_channels
=
flow_model
.
in_channels
noise
=
torch
.
randn
(
num_samples
,
in_channels
,
reso
,
reso
,
reso
).
to
(
self
.
device
)
sampler_params
=
{
**
self
.
sparse_structure_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model
.
to
(
self
.
device
)
z_s
=
self
.
sparse_structure_sampler
.
sample
(
flow_model
,
noise
,
**
cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling sparse structure"
,
).
samples
if
self
.
low_vram
:
flow_model
.
cpu
()
# Decode sparse structure latent
decoder
=
self
.
models
[
'sparse_structure_decoder'
]
if
self
.
low_vram
:
decoder
.
to
(
self
.
device
)
decoded
=
decoder
(
z_s
)
>
0
if
self
.
low_vram
:
decoder
.
cpu
()
if
resolution
!=
decoded
.
shape
[
2
]:
ratio
=
decoded
.
shape
[
2
]
//
resolution
decoded
=
torch
.
nn
.
functional
.
max_pool3d
(
decoded
.
float
(),
ratio
,
ratio
,
0
)
>
0.5
coords
=
torch
.
argwhere
(
decoded
)[:,
[
0
,
2
,
3
,
4
]].
int
()
return
coords
def
sample_shape_slat
(
self
,
cond
:
dict
,
flow_model
,
coords
:
torch
.
Tensor
,
sampler_params
:
dict
=
{},
)
->
SparseTensor
:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
noise
=
SparseTensor
(
feats
=
torch
.
randn
(
coords
.
shape
[
0
],
flow_model
.
in_channels
).
to
(
self
.
device
),
coords
=
coords
,
)
sampler_params
=
{
**
self
.
shape_slat_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model
.
to
(
self
.
device
)
slat
=
self
.
shape_slat_sampler
.
sample
(
flow_model
,
noise
,
**
cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling shape SLat"
,
).
samples
if
self
.
low_vram
:
flow_model
.
cpu
()
std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
])[
None
].
to
(
slat
.
device
)
mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
])[
None
].
to
(
slat
.
device
)
slat
=
slat
*
std
+
mean
return
slat
def
sample_shape_slat_cascade
(
self
,
lr_cond
:
dict
,
cond
:
dict
,
flow_model_lr
,
flow_model
,
lr_resolution
:
int
,
resolution
:
int
,
coords
:
torch
.
Tensor
,
sampler_params
:
dict
=
{},
max_num_tokens
:
int
=
49152
,
visualize_hr_coords
:
bool
=
False
,
visualize_save_dir
:
str
=
None
,
)
->
SparseTensor
:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
visualize_hr_coords (bool): Whether to visualize high-resolution coordinates after upsampling.
visualize_save_dir (str): Directory to save visualization images. If None, displays interactively.
"""
# LR
noise
=
SparseTensor
(
feats
=
torch
.
randn
(
coords
.
shape
[
0
],
flow_model_lr
.
in_channels
).
to
(
self
.
device
),
coords
=
coords
,
)
sampler_params
=
{
**
self
.
shape_slat_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model_lr
.
to
(
self
.
device
)
slat
=
self
.
shape_slat_sampler
.
sample
(
flow_model_lr
,
noise
,
**
lr_cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling shape SLat"
,
).
samples
get_logger
().
debug
(
f
"DEBUG SLAT: coords=
{
slat
.
coords
.
shape
}
, spatial_shape=
{
slat
.
spatial_shape
}
, "
f
"coords_max=
{
slat
.
coords
[:,
1
:].
contiguous
().
max
(
dim
=
0
).
values
}
, dtype=
{
slat
.
feats
.
dtype
}
"
)
if
self
.
low_vram
:
flow_model_lr
.
cpu
()
std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
])[
None
].
to
(
slat
.
device
)
mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
])[
None
].
to
(
slat
.
device
)
slat
=
slat
*
std
+
mean
get_logger
().
debug
(
f
"DEBUG SLAT[after *std + mean]: coords=
{
slat
.
coords
.
shape
}
, spatial_shape=
{
slat
.
spatial_shape
}
, "
f
"coords_max=
{
slat
.
coords
[:,
1
:].
contiguous
().
max
(
dim
=
0
).
values
}
, dtype=
{
slat
.
feats
.
dtype
}
"
)
# Upsample
if
self
.
low_vram
:
self
.
models
[
'shape_slat_decoder'
].
to
(
self
.
device
)
self
.
models
[
'shape_slat_decoder'
].
low_vram
=
True
hr_coords
=
self
.
models
[
'shape_slat_decoder'
].
upsample
(
slat
,
upsample_times
=
4
)
get_logger
().
debug
(
f
"DEBUG CASCADE: hr_coords shape=
{
hr_coords
.
shape
}
, max=
{
hr_coords
[:,
1
:].
max
(
dim
=
0
).
values
}
, unique_x=
{
hr_coords
[:,
1
].
unique
().
shape
[
0
]
}
, unique_y=
{
hr_coords
[:,
2
].
unique
().
shape
[
0
]
}
, unique_z=
{
hr_coords
[:,
3
].
unique
().
shape
[
0
]
}
"
)
# Visualize high-resolution coordinates if requested
if
visualize_hr_coords
:
print
(
"
\n
=== High-Resolution Coordinates Visualization (After Upsampling) ==="
)
self
.
analyze_sparse_structure
(
hr_coords
)
# Calculate effective resolution for visualization
effective_resolution
=
lr_resolution
*
4
# upsample_times=4
if
visualize_save_dir
:
import
os
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
base_path
=
os
.
path
.
join
(
visualize_save_dir
,
f
"hr_coords_
{
resolution
}
_upsampled"
)
self
.
visualize_sparse_structure_matplotlib
(
hr_coords
,
title
=
f
"HR Coordinates - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
,
save_path
=
f
"
{
base_path
}
_3d.png"
)
self
.
visualize_sparse_structure_voxel
(
hr_coords
,
resolution
=
effective_resolution
,
title
=
f
"HR Voxel Grid - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
,
save_path
=
f
"
{
base_path
}
_voxel.png"
)
self
.
visualize_sparse_structure_projections
(
hr_coords
,
resolution
=
effective_resolution
,
title
=
f
"HR Projections - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
,
save_path
=
f
"
{
base_path
}
_projections.png"
)
self
.
visualize_sparse_structure_multi_view
(
hr_coords
,
title
=
f
"HR Multi-View - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
,
save_path
=
f
"
{
base_path
}
_multi_view.png"
)
else
:
# Interactive visualization (no saving)
self
.
visualize_sparse_structure_matplotlib
(
hr_coords
,
title
=
f
"HR Coordinates - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
)
self
.
visualize_sparse_structure_voxel
(
hr_coords
,
resolution
=
effective_resolution
,
title
=
f
"HR Voxel Grid - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
)
self
.
visualize_sparse_structure_projections
(
hr_coords
,
resolution
=
effective_resolution
,
title
=
f
"HR Projections - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
)
self
.
visualize_sparse_structure_multi_view
(
hr_coords
,
title
=
f
"HR Multi-View - Upsampled
{
resolution
}
(effective res:
{
effective_resolution
}
)"
)
print
(
"=== HR Coordinates Visualization Complete ===
\n
"
)
coord_set
=
set
(
map
(
tuple
,
hr_coords
[:,
1
:].
cpu
().
numpy
().
tolist
()))
has_neighbor
=
sum
(
1
for
c
in
coord_set
if
any
(
(
c
[
0
]
+
dx
,
c
[
1
]
+
dy
,
c
[
2
]
+
dz
)
in
coord_set
for
dx
,
dy
,
dz
in
[(
1
,
0
,
0
),(
-
1
,
0
,
0
),(
0
,
1
,
0
),(
0
,
-
1
,
0
),(
0
,
0
,
1
),(
0
,
0
,
-
1
)]
))
/
len
(
coord_set
)
get_logger
().
debug
(
f
"DEBUG TOPOLOGY: coords=
{
len
(
coord_set
)
}
, neighbor_coverage=
{
has_neighbor
:.
3
f
}
"
)
if
self
.
low_vram
:
self
.
models
[
'shape_slat_decoder'
].
cpu
()
self
.
models
[
'shape_slat_decoder'
].
low_vram
=
False
hr_resolution
=
resolution
while
True
:
quant_coords
=
torch
.
cat
([
hr_coords
[:,
:
1
],
((
hr_coords
[:,
1
:]
+
0.5
)
/
lr_resolution
*
(
hr_resolution
//
16
)).
int
(),
],
dim
=
1
)
coords
=
quant_coords
.
unique
(
dim
=
0
)
get_logger
().
debug
(
f
"DEBUG COORDS: num_tokens=
{
coords
.
shape
[
0
]
}
, max=
{
coords
[:,
1
:].
max
(
dim
=
0
).
values
}
"
)
num_tokens
=
coords
.
shape
[
0
]
if
num_tokens
<
max_num_tokens
or
hr_resolution
==
1024
:
if
hr_resolution
!=
resolution
:
print
(
f
"Due to the limited number of tokens, the resolution is reduced to
{
hr_resolution
}
."
)
break
hr_resolution
-=
128
# Visualize quantized coordinates if requested
if
visualize_hr_coords
:
print
(
"
\n
=== Quantized Coordinates Visualization (After Resolution Adjustment) ==="
)
self
.
analyze_sparse_structure
(
coords
)
if
visualize_save_dir
:
import
os
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
base_path
=
os
.
path
.
join
(
visualize_save_dir
,
f
"quantized_coords_
{
hr_resolution
}
"
)
self
.
visualize_sparse_structure_matplotlib
(
coords
,
title
=
f
"Quantized Coords - Resolution
{
hr_resolution
}
"
,
save_path
=
f
"
{
base_path
}
_3d.png"
)
self
.
visualize_sparse_structure_voxel
(
coords
,
resolution
=
hr_resolution
//
16
,
title
=
f
"Quantized Voxel Grid - Resolution
{
hr_resolution
}
"
,
save_path
=
f
"
{
base_path
}
_voxel.png"
)
self
.
visualize_sparse_structure_projections
(
coords
,
resolution
=
hr_resolution
//
16
,
title
=
f
"Quantized Projections - Resolution
{
hr_resolution
}
"
,
save_path
=
f
"
{
base_path
}
_projections.png"
)
self
.
visualize_sparse_structure_multi_view
(
coords
,
title
=
f
"Quantized Multi-View - Resolution
{
hr_resolution
}
"
,
save_path
=
f
"
{
base_path
}
_multi_view.png"
)
else
:
# Interactive visualization (no saving)
self
.
visualize_sparse_structure_matplotlib
(
coords
,
title
=
f
"Quantized Coords - Resolution
{
hr_resolution
}
"
)
self
.
visualize_sparse_structure_voxel
(
coords
,
resolution
=
hr_resolution
//
16
,
title
=
f
"Quantized Voxel Grid - Resolution
{
hr_resolution
}
"
)
self
.
visualize_sparse_structure_projections
(
coords
,
resolution
=
hr_resolution
//
16
,
title
=
f
"Quantized Projections - Resolution
{
hr_resolution
}
"
)
self
.
visualize_sparse_structure_multi_view
(
coords
,
title
=
f
"Quantized Multi-View - Resolution
{
hr_resolution
}
"
)
print
(
"=== Quantized Coordinates Visualization Complete ===
\n
"
)
# Sample structured latent
noise
=
SparseTensor
(
feats
=
torch
.
randn
(
coords
.
shape
[
0
],
flow_model
.
in_channels
).
to
(
self
.
device
),
coords
=
coords
,
)
sampler_params
=
{
**
self
.
shape_slat_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model
.
to
(
self
.
device
)
slat
=
self
.
shape_slat_sampler
.
sample
(
flow_model
,
noise
,
**
cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling shape SLat"
,
).
samples
if
self
.
low_vram
:
flow_model
.
cpu
()
std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
])[
None
].
to
(
slat
.
device
)
mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
])[
None
].
to
(
slat
.
device
)
slat
=
slat
*
std
+
mean
get_logger
().
debug
(
f
"CASCADE final slat: nan=
{
torch
.
isnan
(
slat
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
slat
.
feats
).
any
().
item
()
}
max=
{
slat
.
feats
.
abs
().
max
().
item
():.
4
f
}
dtype=
{
slat
.
feats
.
dtype
}
"
)
# Visualize final SLat features if requested
if
visualize_hr_coords
:
print
(
"
\n
=== Final SLat Features Visualization (After Denormalization) ==="
)
self
.
analyze_slat_features
(
slat
)
if
visualize_save_dir
:
import
os
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
base_path
=
os
.
path
.
join
(
visualize_save_dir
,
f
"final_slat_
{
hr_resolution
}
"
)
# Visualize first few features
for
i
in
range
(
min
(
3
,
slat
.
feats
.
shape
[
1
])):
self
.
visualize_slat_features
(
slat
,
title
=
f
"Final SLat Feature
{
i
}
- Resolution
{
hr_resolution
}
"
,
save_path
=
f
"
{
base_path
}
_feature
{
i
}
.png"
,
feature_idx
=
i
)
else
:
# Interactive visualization (no saving)
for
i
in
range
(
min
(
3
,
slat
.
feats
.
shape
[
1
])):
self
.
visualize_slat_features
(
slat
,
title
=
f
"Final SLat Feature
{
i
}
- Resolution
{
hr_resolution
}
"
,
feature_idx
=
i
)
print
(
"=== Final SLat Features Visualization Complete ===
\n
"
)
return
slat
,
hr_resolution
def
decode_shape_slat
(
self
,
slat
:
SparseTensor
,
resolution
:
int
,
)
->
Tuple
[
List
[
Mesh
],
List
[
SparseTensor
]]:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
List[Mesh]: The decoded meshes.
List[SparseTensor]: The decoded substructures.
"""
self
.
models
[
'shape_slat_decoder'
].
set_resolution
(
resolution
)
if
self
.
low_vram
:
self
.
models
[
'shape_slat_decoder'
].
to
(
self
.
device
)
self
.
models
[
'shape_slat_decoder'
].
low_vram
=
True
ret
=
self
.
models
[
'shape_slat_decoder'
](
slat
,
return_subs
=
True
)
if
self
.
low_vram
:
self
.
models
[
'shape_slat_decoder'
].
cpu
()
self
.
models
[
'shape_slat_decoder'
].
low_vram
=
False
return
ret
def
sample_tex_slat
(
self
,
cond
:
dict
,
flow_model
,
shape_slat
:
SparseTensor
,
sampler_params
:
dict
=
{},
visualize
:
bool
=
False
,
visualize_save_dir
:
str
=
None
,
pipeline_type
:
str
=
'unknown'
,
)
->
SparseTensor
:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
shape_slat (SparseTensor): The structured latent for shape
sampler_params (dict): Additional parameters for the sampler.
visualize (bool): Whether to visualize shape + colored texture slat.
visualize_save_dir (str): Directory to save visualizations. None = interactive.
pipeline_type (str): Pipeline name used in visualization titles.
"""
# Sample structured latent
std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
])[
None
].
to
(
shape_slat
.
device
)
mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
])[
None
].
to
(
shape_slat
.
device
)
shape_slat_norm
=
(
shape_slat
-
mean
)
/
std
in_channels
=
flow_model
.
in_channels
if
isinstance
(
flow_model
,
nn
.
Module
)
else
flow_model
[
0
].
in_channels
noise
=
shape_slat_norm
.
replace
(
feats
=
torch
.
randn
(
shape_slat_norm
.
coords
.
shape
[
0
],
in_channels
-
shape_slat_norm
.
feats
.
shape
[
1
]).
to
(
self
.
device
))
sampler_params
=
{
**
self
.
tex_slat_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model
.
to
(
self
.
device
)
slat
=
self
.
tex_slat_sampler
.
sample
(
flow_model
,
noise
,
concat_cond
=
shape_slat_norm
,
**
cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling texture SLat"
,
).
samples
if
self
.
low_vram
:
flow_model
.
cpu
()
# Visualize: shape structure + colored texture slat
if
visualize
:
import
os
print
(
"
\n
=== Texture SLat Visualization ==="
)
self
.
analyze_slat_features
(
slat
)
if
visualize_save_dir
:
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
base_path
=
os
.
path
.
join
(
visualize_save_dir
,
f
"tex_slat_
{
pipeline_type
}
"
)
# 1. Shape-only structure (occupancy/geometry)
self
.
visualize_sparse_structure_projections
(
shape_slat
.
coords
,
title
=
f
"Shape Structure -
{
pipeline_type
}
"
,
save_path
=
f
"
{
base_path
}
_shape_projections.png"
,
)
# 2. Combined: shape colored by tex-slat latent features (pseudo-RGB from first 3 dims)
self
.
visualize_tex_slat_colored
(
slat
,
title
=
f
"Tex SLat Colored -
{
pipeline_type
}
"
,
save_path
=
f
"
{
base_path
}
_colored.png"
,
)
# 3. Per-feature projections (first 3 latent dims)
for
i
in
range
(
min
(
3
,
slat
.
feats
.
shape
[
1
])):
self
.
visualize_slat_features
(
slat
,
title
=
f
"Tex Feature
{
i
}
-
{
pipeline_type
}
"
,
save_path
=
f
"
{
base_path
}
_feature
{
i
}
.png"
,
feature_idx
=
i
,
)
else
:
self
.
visualize_sparse_structure_projections
(
shape_slat
.
coords
,
title
=
f
"Shape Structure -
{
pipeline_type
}
"
,
)
self
.
visualize_tex_slat_colored
(
slat
,
title
=
f
"Tex SLat Colored -
{
pipeline_type
}
"
,
)
for
i
in
range
(
min
(
3
,
slat
.
feats
.
shape
[
1
])):
self
.
visualize_slat_features
(
slat
,
title
=
f
"Tex Feature
{
i
}
-
{
pipeline_type
}
"
,
feature_idx
=
i
,
)
print
(
"=== Texture SLat Visualization Complete ===
\n
"
)
std
=
torch
.
tensor
(
self
.
tex_slat_normalization
[
'std'
])[
None
].
to
(
slat
.
device
)
mean
=
torch
.
tensor
(
self
.
tex_slat_normalization
[
'mean'
])[
None
].
to
(
slat
.
device
)
slat
=
slat
*
std
+
mean
return
slat
def
decode_tex_slat
(
self
,
slat
:
SparseTensor
,
subs
:
List
[
SparseTensor
],
)
->
SparseTensor
:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
SparseTensor: The decoded texture voxels
"""
if
self
.
low_vram
:
self
.
models
[
'tex_slat_decoder'
].
to
(
self
.
device
)
ret
=
self
.
models
[
'tex_slat_decoder'
](
slat
,
guide_subs
=
subs
)
*
0.5
+
0.5
if
self
.
low_vram
:
self
.
models
[
'tex_slat_decoder'
].
cpu
()
return
ret
def
visualize_sparse_structure_matplotlib
(
self
,
coords
:
torch
.
Tensor
,
title
:
str
=
"Sparse Structure"
,
save_path
:
str
=
None
):
"""
Visualize sparse structure coordinates using matplotlib 3D scatter plot.
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
title: Title for the plot
save_path: Optional path to save the figure
"""
# Convert to numpy and extract spatial coordinates (drop batch index)
coords_np
=
coords
.
cpu
().
numpy
()
x
=
coords_np
[:,
1
]
# x coordinate
y
=
coords_np
[:,
2
]
# y coordinate
z
=
coords_np
[:,
3
]
# z coordinate
# Create 3D plot
fig
=
plt
.
figure
(
figsize
=
(
10
,
8
))
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
# Plot points
scatter
=
ax
.
scatter
(
x
,
y
,
z
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.6
)
# Set labels and title
ax
.
set_xlabel
(
'X'
)
ax
.
set_ylabel
(
'Y'
)
ax
.
set_zlabel
(
'Z'
)
ax
.
set_title
(
f
'
{
title
}
\n
{
len
(
coords
)
}
occupied voxels'
)
# Add colorbar
plt
.
colorbar
(
scatter
,
label
=
'Z coordinate'
)
# Set equal aspect ratio
max_range
=
np
.
array
([
x
.
max
()
-
x
.
min
(),
y
.
max
()
-
y
.
min
(),
z
.
max
()
-
z
.
min
()]).
max
()
/
2.0
mid_x
=
(
x
.
max
()
+
x
.
min
())
*
0.5
mid_y
=
(
y
.
max
()
+
y
.
min
())
*
0.5
mid_z
=
(
z
.
max
()
+
z
.
min
())
*
0.5
ax
.
set_xlim
(
mid_x
-
max_range
,
mid_x
+
max_range
)
ax
.
set_ylim
(
mid_y
-
max_range
,
mid_y
+
max_range
)
ax
.
set_zlim
(
mid_z
-
max_range
,
mid_z
+
max_range
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved matplotlib visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
visualize_sparse_structure_voxel
(
self
,
coords
:
torch
.
Tensor
,
resolution
:
int
=
32
,
title
:
str
=
"Sparse Structure"
,
save_path
:
str
=
None
):
"""
Visualize sparse structure as a 3D voxel grid.
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
resolution: Grid resolution (e.g., 32 for 32³ grid)
title: Title for the plot
save_path: Optional path to save the figure
"""
# Create empty 3D grid
grid
=
np
.
zeros
((
resolution
,
resolution
,
resolution
),
dtype
=
bool
)
# Fill in occupied voxels
coords_np
=
coords
.
cpu
().
numpy
()
for
coord
in
coords_np
:
_
,
x
,
y
,
z
=
coord
if
0
<=
x
<
resolution
and
0
<=
y
<
resolution
and
0
<=
z
<
resolution
:
grid
[
x
,
y
,
z
]
=
True
# Get coordinates of occupied voxels
x
,
y
,
z
=
np
.
where
(
grid
)
# Create 3D plot
fig
=
plt
.
figure
(
figsize
=
(
10
,
8
))
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
# Plot voxels
ax
.
scatter
(
x
,
y
,
z
,
c
=
z
,
cmap
=
'viridis'
,
s
=
10
,
alpha
=
0.3
)
# Set labels
ax
.
set_xlabel
(
'X'
)
ax
.
set_ylabel
(
'Y'
)
ax
.
set_zlabel
(
'Z'
)
ax
.
set_title
(
f
'
{
title
}
\n
{
len
(
coords
)
}
occupied voxels /
{
resolution
**
3
}
total'
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved voxel visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
visualize_sparse_structure_projections
(
self
,
coords
:
torch
.
Tensor
,
resolution
:
int
=
32
,
title
:
str
=
"Sparse Structure"
,
save_path
:
str
=
None
):
"""
Visualize sparse structure using 2D projections (XY, XZ, YZ planes).
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
resolution: Grid resolution
title: Title for the plot
save_path: Optional path to save the figure
"""
coords_np
=
coords
.
cpu
().
numpy
()
x
=
coords_np
[:,
1
]
y
=
coords_np
[:,
2
]
z
=
coords_np
[:,
3
]
# Create figure with 3 subplots
fig
,
axes
=
plt
.
subplots
(
1
,
3
,
figsize
=
(
15
,
5
))
# XY projection (looking down Z axis)
axes
[
0
].
scatter
(
x
,
y
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
0
].
set_xlabel
(
'X'
)
axes
[
0
].
set_ylabel
(
'Y'
)
axes
[
0
].
set_title
(
'XY Projection (Top View)'
)
axes
[
0
].
set_xlim
(
0
,
resolution
)
axes
[
0
].
set_ylim
(
0
,
resolution
)
axes
[
0
].
set_aspect
(
'equal'
)
# XZ projection (looking down Y axis)
axes
[
1
].
scatter
(
x
,
z
,
c
=
y
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
1
].
set_xlabel
(
'X'
)
axes
[
1
].
set_ylabel
(
'Z'
)
axes
[
1
].
set_title
(
'XZ Projection (Side View)'
)
axes
[
1
].
set_xlim
(
0
,
resolution
)
axes
[
1
].
set_ylim
(
0
,
resolution
)
axes
[
1
].
set_aspect
(
'equal'
)
# YZ projection (looking down X axis)
axes
[
2
].
scatter
(
y
,
z
,
c
=
x
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
2
].
set_xlabel
(
'Y'
)
axes
[
2
].
set_ylabel
(
'Z'
)
axes
[
2
].
set_title
(
'YZ Projection (Front View)'
)
axes
[
2
].
set_xlim
(
0
,
resolution
)
axes
[
2
].
set_ylim
(
0
,
resolution
)
axes
[
2
].
set_aspect
(
'equal'
)
plt
.
suptitle
(
f
'
{
title
}
\n
{
len
(
coords
)
}
occupied voxels'
,
fontsize
=
14
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved projections visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
visualize_sparse_structure_multi_view
(
self
,
coords
:
torch
.
Tensor
,
title
:
str
=
"Sparse Structure"
,
save_path
:
str
=
None
):
"""
Visualize sparse structure with multiple views (3D + 2D projections).
Args:
coords: torch.Tensor of shape [N, 4] with [batch, x, y, z]
title: Title for the plot
save_path: Optional path to save the figure
"""
import
matplotlib.pyplot
as
plt
import
numpy
as
np
coords_np
=
coords
.
cpu
().
numpy
()
x
,
y
,
z
=
coords_np
[:,
1
],
coords_np
[:,
2
],
coords_np
[:,
3
]
# Create multi-view visualization
fig
=
plt
.
figure
(
figsize
=
(
18
,
6
))
# 3D scatter plot
ax1
=
fig
.
add_subplot
(
131
,
projection
=
'3d'
)
ax1
.
scatter
(
x
,
y
,
z
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.6
)
ax1
.
set_title
(
'3D View'
)
ax1
.
set_xlabel
(
'X'
);
ax1
.
set_ylabel
(
'Y'
);
ax1
.
set_zlabel
(
'Z'
)
# XY projection
ax2
=
fig
.
add_subplot
(
132
)
ax2
.
scatter
(
x
,
y
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
ax2
.
set_title
(
'XY Projection'
)
ax2
.
set_xlabel
(
'X'
);
ax2
.
set_ylabel
(
'Y'
)
ax2
.
set_aspect
(
'equal'
)
# XZ projection
ax3
=
fig
.
add_subplot
(
133
)
ax3
.
scatter
(
x
,
z
,
c
=
y
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
ax3
.
set_title
(
'XZ Projection'
)
ax3
.
set_xlabel
(
'X'
);
ax3
.
set_ylabel
(
'Z'
)
ax3
.
set_aspect
(
'equal'
)
plt
.
suptitle
(
f
'
{
title
}
\n
{
len
(
coords
)
}
occupied voxels'
,
fontsize
=
14
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved multi-view visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
analyze_sparse_structure
(
self
,
coords
:
torch
.
Tensor
):
"""
Analyze and print statistics about the sparse structure.
Args:
coords: torch.Tensor of shape [N, 4]
"""
coords_np
=
coords
.
cpu
().
numpy
()
x
,
y
,
z
=
coords_np
[:,
1
],
coords_np
[:,
2
],
coords_np
[:,
3
]
print
(
f
"Sparse Structure Analysis:"
)
print
(
f
" Total occupied voxels:
{
len
(
coords
)
}
"
)
print
(
f
" X range: [
{
x
.
min
()
}
,
{
x
.
max
()
}
]"
)
print
(
f
" Y range: [
{
y
.
min
()
}
,
{
y
.
max
()
}
]"
)
print
(
f
" Z range: [
{
z
.
min
()
}
,
{
z
.
max
()
}
]"
)
print
(
f
" Center: [
{
x
.
mean
():.
1
f
}
,
{
y
.
mean
():.
1
f
}
,
{
z
.
mean
():.
1
f
}
]"
)
print
(
f
" Std dev: [
{
x
.
std
():.
1
f
}
,
{
y
.
std
():.
1
f
}
,
{
z
.
std
():.
1
f
}
]"
)
print
(
f
" Bounding box volume:
{
(
x
.
max
()
-
x
.
min
())
*
(
y
.
max
()
-
y
.
min
())
*
(
z
.
max
()
-
z
.
min
())
}
"
)
def
visualize_slat_features
(
self
,
slat
:
SparseTensor
,
title
:
str
=
"SLat Features"
,
save_path
:
str
=
None
,
feature_idx
:
int
=
0
):
"""
Visualize features from a SparseTensor (shape SLat).
Args:
slat: SparseTensor with features at sparse coordinates
title: Title for the plot
save_path: Optional path to save the figure
feature_idx: Which feature dimension to visualize (default: 0)
"""
coords_np
=
slat
.
coords
.
cpu
().
numpy
()
feats_np
=
slat
.
feats
.
cpu
().
numpy
()
# Extract coordinates and selected feature
x
=
coords_np
[:,
1
]
y
=
coords_np
[:,
2
]
z
=
coords_np
[:,
3
]
feature_values
=
feats_np
[:,
feature_idx
]
# Create 3D plot
fig
=
plt
.
figure
(
figsize
=
(
10
,
8
))
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
# Plot points colored by feature value
scatter
=
ax
.
scatter
(
x
,
y
,
z
,
c
=
feature_values
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.6
)
# Set labels and title
ax
.
set_xlabel
(
'X'
)
ax
.
set_ylabel
(
'Y'
)
ax
.
set_zlabel
(
'Z'
)
ax
.
set_title
(
f
'
{
title
}
\n
Feature
{
feature_idx
}
| Range: [
{
feature_values
.
min
():.
3
f
}
,
{
feature_values
.
max
():.
3
f
}
]'
)
# Add colorbar
plt
.
colorbar
(
scatter
,
label
=
f
'Feature
{
feature_idx
}
Value'
)
# Set equal aspect ratio
max_range
=
np
.
array
([
x
.
max
()
-
x
.
min
(),
y
.
max
()
-
y
.
min
(),
z
.
max
()
-
z
.
min
()]).
max
()
/
2.0
mid_x
=
(
x
.
max
()
+
x
.
min
())
*
0.5
mid_y
=
(
y
.
max
()
+
y
.
min
())
*
0.5
mid_z
=
(
z
.
max
()
+
z
.
min
())
*
0.5
ax
.
set_xlim
(
mid_x
-
max_range
,
mid_x
+
max_range
)
ax
.
set_ylim
(
mid_y
-
max_range
,
mid_y
+
max_range
)
ax
.
set_zlim
(
mid_z
-
max_range
,
mid_z
+
max_range
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved SLat feature visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
visualize_tex_slat_colored
(
self
,
slat
:
SparseTensor
,
title
:
str
=
"Tex SLat Colored"
,
save_path
:
str
=
None
):
"""
Visualize texture SLat with points colored by the first 3 latent feature dimensions
mapped to R, G, B — giving a pseudo-color view of the texture distribution across the shape.
Also shows three 2D projection panels (XY/XZ/YZ) beside the 3D view so you can see
coverage completeness at a glance.
Args:
slat: SparseTensor with texture latent features [N, C]
title: Title for the plot
save_path: Optional path to save the figure. None = interactive display.
"""
import
numpy
as
np
coords_np
=
slat
.
coords
.
cpu
().
float
().
numpy
()
feats_np
=
slat
.
feats
.
cpu
().
float
().
numpy
()
x
=
coords_np
[:,
1
]
y
=
coords_np
[:,
2
]
z
=
coords_np
[:,
3
]
# Build per-point RGB from first 3 feature dims, normalised to [0, 1]
n_color_dims
=
min
(
3
,
feats_np
.
shape
[
1
])
rgb
=
feats_np
[:,
:
n_color_dims
].
copy
()
for
ch
in
range
(
n_color_dims
):
lo
,
hi
=
rgb
[:,
ch
].
min
(),
rgb
[:,
ch
].
max
()
rgb
[:,
ch
]
=
(
rgb
[:,
ch
]
-
lo
)
/
(
hi
-
lo
+
1e-8
)
if
n_color_dims
<
3
:
pad
=
np
.
ones
((
rgb
.
shape
[
0
],
3
-
n_color_dims
))
rgb
=
np
.
concatenate
([
rgb
,
pad
],
axis
=
1
)
rgb
=
np
.
clip
(
rgb
,
0.0
,
1.0
)
fig
=
plt
.
figure
(
figsize
=
(
22
,
6
))
fig
.
suptitle
(
f
'
{
title
}
(
{
len
(
x
)
}
voxels,
{
feats_np
.
shape
[
1
]
}
feat dims)'
,
fontsize
=
13
)
# 3D scatter coloured by pseudo-RGB
ax3d
=
fig
.
add_subplot
(
141
,
projection
=
'3d'
)
ax3d
.
scatter
(
x
,
y
,
z
,
c
=
rgb
,
s
=
1
,
alpha
=
0.6
)
ax3d
.
set_xlabel
(
'X'
);
ax3d
.
set_ylabel
(
'Y'
);
ax3d
.
set_zlabel
(
'Z'
)
ax3d
.
set_title
(
'3D (pseudo-RGB)'
)
# XY projection
ax_xy
=
fig
.
add_subplot
(
142
)
ax_xy
.
scatter
(
x
,
y
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
ax_xy
.
set_xlabel
(
'X'
);
ax_xy
.
set_ylabel
(
'Y'
)
ax_xy
.
set_title
(
'XY (top)'
)
ax_xy
.
set_aspect
(
'equal'
)
# XZ projection
ax_xz
=
fig
.
add_subplot
(
143
)
ax_xz
.
scatter
(
x
,
z
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
ax_xz
.
set_xlabel
(
'X'
);
ax_xz
.
set_ylabel
(
'Z'
)
ax_xz
.
set_title
(
'XZ (side)'
)
ax_xz
.
set_aspect
(
'equal'
)
# YZ projection
ax_yz
=
fig
.
add_subplot
(
144
)
ax_yz
.
scatter
(
y
,
z
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
ax_yz
.
set_xlabel
(
'Y'
);
ax_yz
.
set_ylabel
(
'Z'
)
ax_yz
.
set_title
(
'YZ (front)'
)
ax_yz
.
set_aspect
(
'equal'
)
plt
.
tight_layout
()
if
save_path
:
plt
.
savefig
(
save_path
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved tex-slat colored visualization to
{
save_path
}
"
)
plt
.
show
()
plt
.
close
()
def
visualize_decoded_mesh
(
self
,
mesh
,
title
:
str
=
"Decoded Mesh"
,
save_path_prefix
:
str
=
None
):
"""
Visualize a decoded triangle mesh (vertices + faces).
Renders four panels:
- 3D scatter of vertices coloured by Z (subsampled to ≤50k points so matplotlib doesn't choke)
- XY / XZ / YZ 2D projections
Saves four separate PNGs when save_path_prefix is given (one per panel style matches
the naming convention used elsewhere in the pipeline):
<prefix>_3d.png, <prefix>_projections.png
"""
import
numpy
as
np
import
os
verts
=
mesh
.
vertices
.
cpu
().
float
().
numpy
()
# [V, 3]
n_verts
=
verts
.
shape
[
0
]
n_faces
=
mesh
.
faces
.
shape
[
0
]
MAX_SCATTER
=
50_000
if
n_verts
>
MAX_SCATTER
:
idx
=
np
.
random
.
choice
(
n_verts
,
MAX_SCATTER
,
replace
=
False
)
v
=
verts
[
idx
]
else
:
v
=
verts
x
,
y
,
z
=
v
[:,
0
],
v
[:,
1
],
v
[:,
2
]
subtitle
=
f
"
{
n_verts
:,
}
vertices
{
n_faces
:,
}
faces"
+
(
f
" (scatter:
{
len
(
x
):,
}
sampled)"
if
n_verts
>
MAX_SCATTER
else
""
)
# --- 3D scatter ---
fig
=
plt
.
figure
(
figsize
=
(
10
,
8
))
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
ax
.
scatter
(
x
,
y
,
z
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.6
)
ax
.
set_xlabel
(
'X'
);
ax
.
set_ylabel
(
'Y'
);
ax
.
set_zlabel
(
'Z'
)
ax
.
set_title
(
f
'
{
title
}
\n
{
subtitle
}
'
)
plt
.
tight_layout
()
if
save_path_prefix
:
p
=
f
"
{
save_path_prefix
}
_3d.png"
plt
.
savefig
(
p
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved decoded mesh 3D to
{
p
}
"
)
plt
.
show
();
plt
.
close
()
# --- 3-panel 2D projections ---
fig
,
axes
=
plt
.
subplots
(
1
,
3
,
figsize
=
(
18
,
6
))
axes
[
0
].
scatter
(
x
,
y
,
c
=
z
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
0
].
set_xlabel
(
'X'
);
axes
[
0
].
set_ylabel
(
'Y'
);
axes
[
0
].
set_title
(
'XY (top)'
)
axes
[
0
].
set_aspect
(
'equal'
)
axes
[
1
].
scatter
(
x
,
z
,
c
=
y
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
1
].
set_xlabel
(
'X'
);
axes
[
1
].
set_ylabel
(
'Z'
);
axes
[
1
].
set_title
(
'XZ (side)'
)
axes
[
1
].
set_aspect
(
'equal'
)
axes
[
2
].
scatter
(
y
,
z
,
c
=
x
,
cmap
=
'viridis'
,
s
=
1
,
alpha
=
0.5
)
axes
[
2
].
set_xlabel
(
'Y'
);
axes
[
2
].
set_ylabel
(
'Z'
);
axes
[
2
].
set_title
(
'YZ (front)'
)
axes
[
2
].
set_aspect
(
'equal'
)
plt
.
suptitle
(
f
'
{
title
}
\n
{
subtitle
}
'
,
fontsize
=
13
)
plt
.
tight_layout
()
if
save_path_prefix
:
p
=
f
"
{
save_path_prefix
}
_projections.png"
plt
.
savefig
(
p
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved decoded mesh projections to
{
p
}
"
)
plt
.
show
();
plt
.
close
()
def
visualize_mesh_with_voxel
(
self
,
mv
,
title
:
str
=
"MeshWithVoxel"
,
save_path_prefix
:
str
=
None
):
"""
Visualize a MeshWithVoxel: overlays mesh vertices (grey) and texture voxel positions
(coloured by pseudo-RGB from first 3 attr dims) in one 5-panel figure.
Panels: 3D overlay, XY / XZ / YZ 2D projections.
"""
import
numpy
as
np
verts
=
mv
.
vertices
.
cpu
().
float
().
numpy
()
n_verts
=
verts
.
shape
[
0
]
n_faces
=
mv
.
faces
.
shape
[
0
]
coords
=
mv
.
coords
.
cpu
().
float
().
numpy
()
# [N, 3] (already stripped of batch dim)
attrs
=
mv
.
attrs
.
cpu
().
float
().
numpy
()
# [N, C]
n_vox
=
coords
.
shape
[
0
]
MAX_SCATTER
=
50_000
if
n_verts
>
MAX_SCATTER
:
vi
=
np
.
random
.
choice
(
n_verts
,
MAX_SCATTER
,
replace
=
False
)
vp
=
verts
[
vi
]
else
:
vp
=
verts
if
n_vox
>
MAX_SCATTER
:
ci
=
np
.
random
.
choice
(
n_vox
,
MAX_SCATTER
,
replace
=
False
)
cp
=
coords
[
ci
];
ap
=
attrs
[
ci
]
else
:
cp
=
coords
;
ap
=
attrs
# Build pseudo-RGB from first 3 attr dims
n_color
=
min
(
3
,
ap
.
shape
[
1
])
rgb
=
ap
[:,
:
n_color
].
copy
()
for
ch
in
range
(
n_color
):
lo
,
hi
=
rgb
[:,
ch
].
min
(),
rgb
[:,
ch
].
max
()
rgb
[:,
ch
]
=
(
rgb
[:,
ch
]
-
lo
)
/
(
hi
-
lo
+
1e-8
)
if
n_color
<
3
:
rgb
=
np
.
concatenate
([
rgb
,
np
.
ones
((
rgb
.
shape
[
0
],
3
-
n_color
))],
axis
=
1
)
rgb
=
np
.
clip
(
rgb
,
0
,
1
)
subtitle
=
(
f
"Mesh:
{
n_verts
:,
}
v
{
n_faces
:,
}
f | Voxels:
{
n_vox
:,
}
"
+
(
" (both subsampled)"
if
n_verts
>
MAX_SCATTER
or
n_vox
>
MAX_SCATTER
else
""
))
# Voxel coords are integer indices; convert to world space for overlay
vox_world
=
cp
*
mv
.
voxel_size
+
mv
.
origin
.
cpu
().
numpy
()
vx
,
vy
,
vz
=
vox_world
[:,
0
],
vox_world
[:,
1
],
vox_world
[:,
2
]
mx
,
my
,
mz
=
vp
[:,
0
],
vp
[:,
1
],
vp
[:,
2
]
# --- 3D overlay ---
fig
=
plt
.
figure
(
figsize
=
(
11
,
8
))
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
ax
.
scatter
(
mx
,
my
,
mz
,
c
=
'lightgrey'
,
s
=
1
,
alpha
=
0.3
,
label
=
'mesh verts'
)
ax
.
scatter
(
vx
,
vy
,
vz
,
c
=
rgb
,
s
=
2
,
alpha
=
0.6
,
label
=
'tex voxels'
)
ax
.
set_xlabel
(
'X'
);
ax
.
set_ylabel
(
'Y'
);
ax
.
set_zlabel
(
'Z'
)
ax
.
set_title
(
f
'
{
title
}
\n
{
subtitle
}
'
)
plt
.
tight_layout
()
if
save_path_prefix
:
p
=
f
"
{
save_path_prefix
}
_3d.png"
plt
.
savefig
(
p
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved MeshWithVoxel 3D to
{
p
}
"
)
plt
.
show
();
plt
.
close
()
# --- 4-panel 2D projections ---
fig
,
axes
=
plt
.
subplots
(
1
,
4
,
figsize
=
(
24
,
6
))
def
proj
(
ax_
,
hx
,
hy
,
hz
,
label
):
ax_
.
scatter
(
hx
,
hy
,
c
=
'lightgrey'
,
s
=
1
,
alpha
=
0.25
)
ax_
.
scatter
(
vx
if
label
==
'XY'
else
(
vx
if
label
==
'XZ'
else
vy
),
vy
if
label
==
'XY'
else
(
vz
if
label
==
'XZ'
else
vz
),
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
ax_
.
set_aspect
(
'equal'
)
axes
[
0
].
scatter
(
mx
,
my
,
c
=
'lightgrey'
,
s
=
1
,
alpha
=
0.25
)
axes
[
0
].
scatter
(
vx
,
vy
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
axes
[
0
].
set_xlabel
(
'X'
);
axes
[
0
].
set_ylabel
(
'Y'
);
axes
[
0
].
set_title
(
'XY (top)'
);
axes
[
0
].
set_aspect
(
'equal'
)
axes
[
1
].
scatter
(
mx
,
mz
,
c
=
'lightgrey'
,
s
=
1
,
alpha
=
0.25
)
axes
[
1
].
scatter
(
vx
,
vz
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
axes
[
1
].
set_xlabel
(
'X'
);
axes
[
1
].
set_ylabel
(
'Z'
);
axes
[
1
].
set_title
(
'XZ (side)'
);
axes
[
1
].
set_aspect
(
'equal'
)
axes
[
2
].
scatter
(
my
,
mz
,
c
=
'lightgrey'
,
s
=
1
,
alpha
=
0.25
)
axes
[
2
].
scatter
(
vy
,
vz
,
c
=
rgb
,
s
=
1
,
alpha
=
0.5
)
axes
[
2
].
set_xlabel
(
'Y'
);
axes
[
2
].
set_ylabel
(
'Z'
);
axes
[
2
].
set_title
(
'YZ (front)'
);
axes
[
2
].
set_aspect
(
'equal'
)
# 4th panel: voxel coverage ratio as bar chart per axis
axes
[
3
].
axis
(
'off'
)
info
=
(
f
"Mesh vertices :
{
n_verts
:,
}
\n
"
f
"Mesh faces :
{
n_faces
:,
}
\n
"
f
"Tex voxels :
{
n_vox
:,
}
\n
"
f
"Voxel size :
{
mv
.
voxel_size
:.
5
f
}
\n
"
f
"Voxel world X : [
{
vx
.
min
():.
3
f
}
,
{
vx
.
max
():.
3
f
}
]
\n
"
f
"Voxel world Y : [
{
vy
.
min
():.
3
f
}
,
{
vy
.
max
():.
3
f
}
]
\n
"
f
"Voxel world Z : [
{
vz
.
min
():.
3
f
}
,
{
vz
.
max
():.
3
f
}
]
\n
"
f
"Attr dims :
{
mv
.
attrs
.
shape
[
1
]
}
\n
"
f
"Attr range : [
{
mv
.
attrs
.
min
().
item
():.
4
f
}
,
{
mv
.
attrs
.
max
().
item
():.
4
f
}
]"
)
axes
[
3
].
text
(
0.05
,
0.95
,
info
,
transform
=
axes
[
3
].
transAxes
,
fontsize
=
10
,
verticalalignment
=
'top'
,
fontfamily
=
'monospace'
)
axes
[
3
].
set_title
(
'Stats'
)
plt
.
suptitle
(
f
'
{
title
}
\n
{
subtitle
}
'
,
fontsize
=
13
)
plt
.
tight_layout
()
if
save_path_prefix
:
p
=
f
"
{
save_path_prefix
}
_projections.png"
plt
.
savefig
(
p
,
dpi
=
150
,
bbox_inches
=
'tight'
)
print
(
f
"Saved MeshWithVoxel projections to
{
p
}
"
)
plt
.
show
();
plt
.
close
()
def
analyze_slat_features
(
self
,
slat
:
SparseTensor
):
"""
Analyze and print statistics about SLat features.
Args:
slat: SparseTensor with features
"""
coords_np
=
slat
.
coords
.
cpu
().
numpy
()
feats_np
=
slat
.
feats
.
cpu
().
numpy
()
print
(
f
"
\n
SLat Features Analysis:"
)
print
(
f
" Number of tokens:
{
slat
.
coords
.
shape
[
0
]
}
"
)
print
(
f
" Feature dimensions:
{
slat
.
feats
.
shape
[
1
]
}
"
)
print
(
f
" Feature statistics:"
)
for
i
in
range
(
min
(
5
,
slat
.
feats
.
shape
[
1
])):
# Show first 5 features
feat
=
feats_np
[:,
i
]
print
(
f
" Feature
{
i
}
: min=
{
feat
.
min
():.
4
f
}
, max=
{
feat
.
max
():.
4
f
}
, mean=
{
feat
.
mean
():.
4
f
}
, std=
{
feat
.
std
():.
4
f
}
"
)
print
(
f
" NaN values:
{
np
.
isnan
(
feats_np
).
any
()
}
"
)
print
(
f
" Inf values:
{
np
.
isinf
(
feats_np
).
any
()
}
"
)
print
(
f
" Coordinate range: X=[
{
coords_np
[:,
1
].
min
()
}
,
{
coords_np
[:,
1
].
max
()
}
], "
f
"Y=[
{
coords_np
[:,
2
].
min
()
}
,
{
coords_np
[:,
2
].
max
()
}
], "
f
"Z=[
{
coords_np
[:,
3
].
min
()
}
,
{
coords_np
[:,
3
].
max
()
}
]"
)
@
torch
.
no_grad
()
def
decode_latent
(
self
,
shape_slat
:
SparseTensor
,
tex_slat
:
SparseTensor
,
resolution
:
int
,
visualize
:
bool
=
False
,
visualize_save_dir
:
str
=
None
,
pipeline_type
:
str
=
'unknown'
,
)
->
List
[
MeshWithVoxel
]:
"""
Decode the latent codes.
Args:
shape_slat (SparseTensor): The structured latent for shape.
tex_slat (SparseTensor): The structured latent for texture.
resolution (int): The resolution of the output.
"""
L
=
get_logger
()
section
(
f
"decode_latent resolution=
{
resolution
}
"
)
section
(
"decode_shape_slat"
)
log_sparse
(
shape_slat
,
"shape_slat-in"
)
meshes
,
subs
=
self
.
decode_shape_slat
(
shape_slat
,
resolution
)
L
.
info
(
f
"
{
elapsed
()
}
decode_shape_slat produced
{
len
(
meshes
)
}
mesh(es)"
)
for
i
,
m
in
enumerate
(
meshes
):
log_mesh
(
m
.
vertices
,
m
.
faces
,
f
"shape_mesh[
{
i
}
]"
)
# Visualize decoded shape meshes
if
visualize
:
import
os
for
i
,
m
in
enumerate
(
meshes
):
base
=
(
os
.
path
.
join
(
visualize_save_dir
,
f
"decoded_mesh_
{
pipeline_type
}
_s
{
i
}
"
)
if
visualize_save_dir
else
None
)
if
base
:
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
self
.
visualize_decoded_mesh
(
m
,
title
=
f
"Decoded Shape Mesh [
{
i
}
] -
{
pipeline_type
}
"
,
save_path_prefix
=
base
)
section
(
"decode_tex_slat"
)
log_sparse
(
tex_slat
,
"tex_slat-in"
)
tex_voxels
=
self
.
decode_tex_slat
(
tex_slat
,
subs
)
L
.
info
(
f
"
{
elapsed
()
}
decode_tex_slat produced
{
len
(
tex_voxels
)
}
voxel set(s)"
)
#Commented temporarily for speed.
"""
# Visualize texture voxels
if visualize:
import os
for i, v in enumerate(tex_voxels):
base = (os.path.join(visualize_save_dir, f"tex_voxels_{pipeline_type}_s{i}")
if visualize_save_dir else None)
if base:
os.makedirs(visualize_save_dir, exist_ok=True)
self.visualize_tex_slat_colored(v,
title=f"Tex Voxels [{i}] - {pipeline_type}",
save_path=f"{base}_colored.png" if base else None)
"""
section
(
"build MeshWithVoxel"
)
out_mesh
=
[]
for
i
,
(
m
,
v
)
in
enumerate
(
zip
(
meshes
,
tex_voxels
)):
L
.
info
(
f
"
{
elapsed
()
}
sample
{
i
}
:"
)
log_sparse
(
v
,
f
"tex_voxels[
{
i
}
]"
)
L
.
info
(
f
" spatial_shape=
{
v
.
spatial_shape
}
"
f
"coords_max=
{
v
.
coords
.
max
(
dim
=
0
).
values
.
tolist
()
}
"
)
log_mesh
(
m
.
vertices
,
m
.
faces
,
f
"before-fill_holes[
{
i
}
]"
)
# CPU simplification via pyfqmr (QEM) before fill_holes to avoid
# GPU OOM from CuMesh's O(F*3) edge buffers on large meshes.
import
pyfqmr
,
time
_target
=
4_000_000
if
m
.
faces
.
shape
[
0
]
>
_target
:
_v_np
=
m
.
vertices
.
detach
().
cpu
().
float
().
numpy
()
_f_np
=
m
.
faces
.
detach
().
cpu
().
int
().
numpy
()
L
.
info
(
f
" [pyfqmr] simplify
{
m
.
faces
.
shape
[
0
]
}
→
{
_target
}
faces ..."
)
_t0
=
time
.
perf_counter
()
_simplifier
=
pyfqmr
.
Simplify
()
_simplifier
.
setMesh
(
_v_np
,
_f_np
)
_simplifier
.
simplify_mesh
(
_target
,
aggressiveness
=
7
,
verbose
=
False
)
_sv
,
_sf
,
_sn
=
_simplifier
.
getMesh
()
_dt
=
time
.
perf_counter
()
-
_t0
L
.
info
(
f
" [pyfqmr] done in
{
_dt
:.
2
f
}
s →
{
len
(
_sv
)
}
verts
{
len
(
_sf
)
}
faces"
)
m
.
vertices
=
torch
.
from_numpy
(
_sv
).
to
(
dtype
=
torch
.
float32
,
device
=
m
.
vertices
.
device
)
m
.
faces
=
torch
.
from_numpy
(
_sf
).
to
(
dtype
=
torch
.
int32
,
device
=
m
.
faces
.
device
)
m
.
fill_holes
()
log_mesh
(
m
.
vertices
,
m
.
faces
,
f
"after-fill_holes[
{
i
}
]"
)
coords_xyz
=
v
.
coords
[:,
1
:].
contiguous
()
L
.
info
(
f
" coords_xyz:
{
list
(
coords_xyz
.
shape
)
}
"
f
"range=
{
[
coords_xyz
.
min
().
item
(),
coords_xyz
.
max
().
item
()]
}
"
)
L
.
info
(
f
" attrs:
{
list
(
v
.
feats
.
shape
)
}
"
f
"range=[
{
v
.
feats
.
min
().
item
():.
4
g
}
,
{
v
.
feats
.
max
().
item
():.
4
g
}
] "
f
"NaN=
{
torch
.
isnan
(
v
.
feats
).
any
().
item
()
}
"
)
L
.
info
(
f
" voxel_size=
{
1
/
resolution
:.
6
f
}
origin=[-0.5,-0.5,-0.5]"
)
mv
=
MeshWithVoxel
(
m
.
vertices
,
m
.
faces
,
origin
=
[
-
0.5
,
-
0.5
,
-
0.5
],
voxel_size
=
1
/
resolution
,
coords
=
coords_xyz
,
attrs
=
v
.
feats
,
voxel_shape
=
torch
.
Size
([
*
v
.
shape
,
*
v
.
spatial_shape
]),
layout
=
self
.
pbr_attr_layout
)
L
.
info
(
f
" MeshWithVoxel.voxel_shape=
{
mv
.
voxel_shape
}
"
f
"voxel_size=
{
mv
.
voxel_size
}
origin=
{
mv
.
origin
}
"
)
# Visualize final MeshWithVoxel
if
visualize
:
import
os
base
=
(
os
.
path
.
join
(
visualize_save_dir
,
f
"mesh_with_voxel_
{
pipeline_type
}
_s
{
i
}
"
)
if
visualize_save_dir
else
None
)
if
base
:
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
self
.
visualize_mesh_with_voxel
(
mv
,
title
=
f
"MeshWithVoxel [
{
i
}
] -
{
pipeline_type
}
"
,
save_path_prefix
=
base
)
out_mesh
.
append
(
mv
)
section
(
"decode_latent complete"
)
return
out_mesh
@
torch
.
no_grad
()
def
run
(
self
,
image
:
Image
.
Image
,
num_samples
:
int
=
1
,
seed
:
int
=
42
,
sparse_structure_sampler_params
:
dict
=
{},
shape_slat_sampler_params
:
dict
=
{},
tex_slat_sampler_params
:
dict
=
{},
preprocess_image
:
bool
=
True
,
return_latent
:
bool
=
False
,
pipeline_type
:
Optional
[
str
]
=
None
,
max_num_tokens
:
int
=
49152
,
visualize_sparse_structure
:
bool
=
False
,
visualize_save_dir
:
str
=
None
,
)
->
List
[
MeshWithVoxel
]:
"""
Run the pipeline.
Args:
image (Image.Image): The image prompt.
num_samples (int): The number of samples to generate.
seed (int): The random seed.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler.
tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
preprocess_image (bool): Whether to preprocess the image.
return_latent (bool): Whether to return the latent codes.
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
max_num_tokens (int): The maximum number of tokens to use.
visualize_sparse_structure (bool): Whether to visualize the sparse structure.
visualize_save_dir (str): Directory to save visualization images. If None, displays interactively.
"""
# Check pipeline type
pipeline_type
=
pipeline_type
or
self
.
default_pipeline_type
if
pipeline_type
==
'512'
:
assert
'shape_slat_flow_model_512'
in
self
.
models
,
"No 512 resolution shape SLat flow model found."
assert
'tex_slat_flow_model_512'
in
self
.
models
,
"No 512 resolution texture SLat flow model found."
elif
pipeline_type
==
'1024'
:
assert
'shape_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution shape SLat flow model found."
assert
'tex_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution texture SLat flow model found."
elif
pipeline_type
==
'1024_cascade'
:
assert
'shape_slat_flow_model_512'
in
self
.
models
,
"No 512 resolution shape SLat flow model found."
assert
'shape_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution shape SLat flow model found."
assert
'tex_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution texture SLat flow model found."
elif
pipeline_type
==
'1536_cascade'
:
assert
'shape_slat_flow_model_512'
in
self
.
models
,
"No 512 resolution shape SLat flow model found."
assert
'shape_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution shape SLat flow model found."
assert
'tex_slat_flow_model_1024'
in
self
.
models
,
"No 1024 resolution texture SLat flow model found."
else
:
raise
ValueError
(
f
"Invalid pipeline type:
{
pipeline_type
}
"
)
if
preprocess_image
:
image
=
self
.
preprocess_image
(
image
)
torch
.
manual_seed
(
seed
)
cond_512
=
self
.
get_cond
([
image
],
512
)
cond_1024
=
self
.
get_cond
([
image
],
1024
)
if
pipeline_type
!=
'512'
else
None
ss_res
=
{
'512'
:
32
,
'1024'
:
64
,
'1024_cascade'
:
32
,
'1536_cascade'
:
32
}[
pipeline_type
]
coords
=
self
.
sample_sparse_structure
(
cond_512
,
ss_res
,
num_samples
,
sparse_structure_sampler_params
)
# Visualize sparse structure if requested
if
visualize_sparse_structure
:
print
(
"
\n
=== Sparse Structure Visualization ==="
)
self
.
analyze_sparse_structure
(
coords
)
if
visualize_save_dir
:
import
os
os
.
makedirs
(
visualize_save_dir
,
exist_ok
=
True
)
base_path
=
os
.
path
.
join
(
visualize_save_dir
,
f
"sparse_structure_
{
pipeline_type
}
_seed
{
seed
}
"
)
self
.
visualize_sparse_structure_matplotlib
(
coords
,
title
=
f
"Sparse Structure -
{
pipeline_type
}
(seed=
{
seed
}
)"
,
save_path
=
f
"
{
base_path
}
_3d.png"
)
self
.
visualize_sparse_structure_voxel
(
coords
,
resolution
=
ss_res
,
title
=
f
"Voxel Grid -
{
pipeline_type
}
(seed=
{
seed
}
)"
,
save_path
=
f
"
{
base_path
}
_voxel.png"
)
self
.
visualize_sparse_structure_projections
(
coords
,
resolution
=
ss_res
,
title
=
f
"Projections -
{
pipeline_type
}
(seed=
{
seed
}
)"
,
save_path
=
f
"
{
base_path
}
_projections.png"
)
self
.
visualize_sparse_structure_multi_view
(
coords
,
title
=
f
"Multi-View -
{
pipeline_type
}
(seed=
{
seed
}
)"
,
save_path
=
f
"
{
base_path
}
_multi_view.png"
)
else
:
# Interactive visualization (no saving)
self
.
visualize_sparse_structure_matplotlib
(
coords
,
title
=
f
"Sparse Structure -
{
pipeline_type
}
(seed=
{
seed
}
)"
)
self
.
visualize_sparse_structure_voxel
(
coords
,
resolution
=
ss_res
,
title
=
f
"Voxel Grid -
{
pipeline_type
}
(seed=
{
seed
}
)"
)
self
.
visualize_sparse_structure_projections
(
coords
,
resolution
=
ss_res
,
title
=
f
"Projections -
{
pipeline_type
}
(seed=
{
seed
}
)"
)
self
.
visualize_sparse_structure_multi_view
(
coords
,
title
=
f
"Multi-View -
{
pipeline_type
}
(seed=
{
seed
}
)"
)
print
(
"=== Visualization Complete ===
\n
"
)
if
pipeline_type
==
'512'
:
shape_slat
=
self
.
sample_shape_slat
(
cond_512
,
self
.
models
[
'shape_slat_flow_model_512'
],
coords
,
shape_slat_sampler_params
)
tex_slat
=
self
.
sample_tex_slat
(
cond_512
,
self
.
models
[
'tex_slat_flow_model_512'
],
shape_slat
,
tex_slat_sampler_params
,
visualize
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
pipeline_type
=
pipeline_type
,
)
res
=
512
elif
pipeline_type
==
'1024'
:
shape_slat
=
self
.
sample_shape_slat
(
cond_1024
,
self
.
models
[
'shape_slat_flow_model_1024'
],
coords
,
shape_slat_sampler_params
)
tex_slat
=
self
.
sample_tex_slat
(
cond_1024
,
self
.
models
[
'tex_slat_flow_model_1024'
],
shape_slat
,
tex_slat_sampler_params
,
visualize
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
pipeline_type
=
pipeline_type
,
)
res
=
1024
elif
pipeline_type
==
'1024_cascade'
:
shape_slat
,
res
=
self
.
sample_shape_slat_cascade
(
cond_512
,
cond_1024
,
self
.
models
[
'shape_slat_flow_model_512'
],
self
.
models
[
'shape_slat_flow_model_1024'
],
512
,
1024
,
coords
,
shape_slat_sampler_params
,
max_num_tokens
,
visualize_hr_coords
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
)
tex_slat
=
self
.
sample_tex_slat
(
cond_1024
,
self
.
models
[
'tex_slat_flow_model_1024'
],
shape_slat
,
tex_slat_sampler_params
,
visualize
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
pipeline_type
=
pipeline_type
,
)
elif
pipeline_type
==
'1536_cascade'
:
shape_slat
,
res
=
self
.
sample_shape_slat_cascade
(
cond_512
,
cond_1024
,
self
.
models
[
'shape_slat_flow_model_512'
],
self
.
models
[
'shape_slat_flow_model_1024'
],
512
,
1536
,
coords
,
shape_slat_sampler_params
,
max_num_tokens
,
visualize_hr_coords
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
)
tex_slat
=
self
.
sample_tex_slat
(
cond_1024
,
self
.
models
[
'tex_slat_flow_model_1024'
],
shape_slat
,
tex_slat_sampler_params
,
visualize
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
pipeline_type
=
pipeline_type
,
)
torch
.
cuda
.
empty_cache
()
out_mesh
=
self
.
decode_latent
(
shape_slat
,
tex_slat
,
res
,
visualize
=
visualize_sparse_structure
,
visualize_save_dir
=
visualize_save_dir
,
pipeline_type
=
pipeline_type
)
if
return_latent
:
return
out_mesh
,
(
shape_slat
,
tex_slat
,
res
)
else
:
return
out_mesh
TRELLIS.2_DCU/trellis2/pipelines/trellis2_texturing.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
PIL
import
Image
import
trimesh
from
.base
import
Pipeline
from
.
import
samplers
,
rembg
from
..modules.sparse
import
SparseTensor
from
..modules
import
image_feature_extractor
import
o_voxel
import
cumesh
import
nvdiffrast.torch
as
dr
import
cv2
import
flex_gemm
class
Trellis2TexturingPipeline
(
Pipeline
):
"""
Pipeline for inferring Trellis2 image-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
shape_slat_normalization (dict): The normalization parameters for the structured latent.
tex_slat_normalization (dict): The normalization parameters for the texture latent.
image_cond_model (Callable): The image conditioning model.
rembg_model (Callable): The model for removing background.
low_vram (bool): Whether to use low-VRAM mode.
"""
model_names_to_load
=
[
'shape_slat_encoder'
,
'tex_slat_decoder'
,
'tex_slat_flow_model_512'
,
'tex_slat_flow_model_1024'
]
def
__init__
(
self
,
models
:
dict
[
str
,
nn
.
Module
]
=
None
,
tex_slat_sampler
:
samplers
.
Sampler
=
None
,
tex_slat_sampler_params
:
dict
=
None
,
shape_slat_normalization
:
dict
=
None
,
tex_slat_normalization
:
dict
=
None
,
image_cond_model
:
Callable
=
None
,
rembg_model
:
Callable
=
None
,
low_vram
:
bool
=
True
,
):
if
models
is
None
:
return
super
().
__init__
(
models
)
self
.
tex_slat_sampler
=
tex_slat_sampler
self
.
tex_slat_sampler_params
=
tex_slat_sampler_params
self
.
shape_slat_normalization
=
shape_slat_normalization
self
.
tex_slat_normalization
=
tex_slat_normalization
self
.
image_cond_model
=
image_cond_model
self
.
rembg_model
=
rembg_model
self
.
low_vram
=
low_vram
self
.
pbr_attr_layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
self
.
_device
=
'cpu'
@
classmethod
def
from_pretrained
(
cls
,
path
:
str
,
config_file
:
str
=
"pipeline.json"
)
->
"Trellis2TexturingPipeline"
:
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline
=
super
().
from_pretrained
(
path
,
config_file
)
args
=
pipeline
.
_pretrained_args
pipeline
.
tex_slat_sampler
=
getattr
(
samplers
,
args
[
'tex_slat_sampler'
][
'name'
])(
**
args
[
'tex_slat_sampler'
][
'args'
])
pipeline
.
tex_slat_sampler_params
=
args
[
'tex_slat_sampler'
][
'params'
]
pipeline
.
shape_slat_normalization
=
args
[
'shape_slat_normalization'
]
pipeline
.
tex_slat_normalization
=
args
[
'tex_slat_normalization'
]
pipeline
.
image_cond_model
=
getattr
(
image_feature_extractor
,
args
[
'image_cond_model'
][
'name'
])(
**
args
[
'image_cond_model'
][
'args'
])
pipeline
.
rembg_model
=
getattr
(
rembg
,
args
[
'rembg_model'
][
'name'
])(
**
args
[
'rembg_model'
][
'args'
])
pipeline
.
low_vram
=
args
.
get
(
'low_vram'
,
True
)
pipeline
.
pbr_attr_layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
pipeline
.
_device
=
'cpu'
return
pipeline
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
self
.
_device
=
device
if
not
self
.
low_vram
:
super
().
to
(
device
)
self
.
image_cond_model
.
to
(
device
)
if
self
.
rembg_model
is
not
None
:
self
.
rembg_model
.
to
(
device
)
def
preprocess_mesh
(
self
,
mesh
:
trimesh
.
Trimesh
)
->
trimesh
.
Trimesh
:
"""
Preprocess the input mesh.
"""
vertices
=
mesh
.
vertices
vertices_min
=
vertices
.
min
(
axis
=
0
)
vertices_max
=
vertices
.
max
(
axis
=
0
)
center
=
(
vertices_min
+
vertices_max
)
/
2
scale
=
0.99999
/
(
vertices_max
-
vertices_min
).
max
()
vertices
=
(
vertices
-
center
)
*
scale
tmp
=
vertices
[:,
1
].
copy
()
vertices
[:,
1
]
=
-
vertices
[:,
2
]
vertices
[:,
2
]
=
tmp
assert
np
.
all
(
vertices
>=
-
0.5
)
and
np
.
all
(
vertices
<=
0.5
),
'vertices out of range'
return
trimesh
.
Trimesh
(
vertices
=
vertices
,
faces
=
mesh
.
faces
,
process
=
False
)
def
preprocess_image
(
self
,
input
:
Image
.
Image
)
->
Image
.
Image
:
"""
Preprocess the input image.
"""
# if has alpha channel, use it directly; otherwise, remove background
has_alpha
=
False
if
input
.
mode
==
'RGBA'
:
alpha
=
np
.
array
(
input
)[:,
:,
3
]
if
not
np
.
all
(
alpha
==
255
):
has_alpha
=
True
max_size
=
max
(
input
.
size
)
scale
=
min
(
1
,
1024
/
max_size
)
if
scale
<
1
:
input
=
input
.
resize
((
int
(
input
.
width
*
scale
),
int
(
input
.
height
*
scale
)),
Image
.
Resampling
.
LANCZOS
)
if
has_alpha
:
output
=
input
else
:
input
=
input
.
convert
(
'RGB'
)
if
self
.
low_vram
:
self
.
rembg_model
.
to
(
self
.
device
)
output
=
self
.
rembg_model
(
input
)
if
self
.
low_vram
:
self
.
rembg_model
.
cpu
()
output_np
=
np
.
array
(
output
)
alpha
=
output_np
[:,
:,
3
]
bbox
=
np
.
argwhere
(
alpha
>
0.8
*
255
)
bbox
=
np
.
min
(
bbox
[:,
1
]),
np
.
min
(
bbox
[:,
0
]),
np
.
max
(
bbox
[:,
1
]),
np
.
max
(
bbox
[:,
0
])
center
=
(
bbox
[
0
]
+
bbox
[
2
])
/
2
,
(
bbox
[
1
]
+
bbox
[
3
])
/
2
size
=
max
(
bbox
[
2
]
-
bbox
[
0
],
bbox
[
3
]
-
bbox
[
1
])
size
=
int
(
size
*
1
)
bbox
=
center
[
0
]
-
size
//
2
,
center
[
1
]
-
size
//
2
,
center
[
0
]
+
size
//
2
,
center
[
1
]
+
size
//
2
output
=
output
.
crop
(
bbox
)
# type: ignore
output
=
np
.
array
(
output
).
astype
(
np
.
float32
)
/
255
output
=
output
[:,
:,
:
3
]
*
output
[:,
:,
3
:
4
]
output
=
Image
.
fromarray
((
output
*
255
).
astype
(
np
.
uint8
))
return
output
def
get_cond
(
self
,
image
:
Union
[
torch
.
Tensor
,
list
[
Image
.
Image
]],
resolution
:
int
,
include_neg_cond
:
bool
=
True
)
->
dict
:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
self
.
image_cond_model
.
image_size
=
resolution
if
self
.
low_vram
:
self
.
image_cond_model
.
to
(
self
.
device
)
cond
=
self
.
image_cond_model
(
image
)
if
self
.
low_vram
:
self
.
image_cond_model
.
cpu
()
if
not
include_neg_cond
:
return
{
'cond'
:
cond
}
neg_cond
=
torch
.
zeros_like
(
cond
)
return
{
'cond'
:
cond
,
'neg_cond'
:
neg_cond
,
}
def
encode_shape_slat
(
self
,
mesh
:
trimesh
.
Trimesh
,
resolution
:
int
=
1024
,
)
->
SparseTensor
:
"""
Encode the meshes to structured latent.
Args:
mesh (trimesh.Trimesh): The mesh to encode.
resolution (int): The resolution of mesh
Returns:
SparseTensor: The encoded structured latent.
"""
vertices
=
torch
.
from_numpy
(
mesh
.
vertices
).
float
()
faces
=
torch
.
from_numpy
(
mesh
.
faces
).
long
()
voxel_indices
,
dual_vertices
,
intersected
=
o_voxel
.
convert
.
mesh_to_flexible_dual_grid
(
vertices
.
cpu
(),
faces
.
cpu
(),
grid_size
=
resolution
,
aabb
=
[[
-
0.5
,
-
0.5
,
-
0.5
],[
0.5
,
0.5
,
0.5
]],
face_weight
=
1.0
,
boundary_weight
=
0.2
,
regularization_weight
=
1e-2
,
timing
=
True
,
)
vertices
=
SparseTensor
(
feats
=
dual_vertices
*
resolution
-
voxel_indices
,
coords
=
torch
.
cat
([
torch
.
zeros_like
(
voxel_indices
[:,
0
:
1
]),
voxel_indices
],
dim
=-
1
)
).
to
(
self
.
device
)
intersected
=
vertices
.
replace
(
intersected
).
to
(
self
.
device
)
if
self
.
low_vram
:
self
.
models
[
'shape_slat_encoder'
].
to
(
self
.
device
)
shape_slat
=
self
.
models
[
'shape_slat_encoder'
](
vertices
,
intersected
)
if
self
.
low_vram
:
self
.
models
[
'shape_slat_encoder'
].
cpu
()
return
shape_slat
def
sample_tex_slat
(
self
,
cond
:
dict
,
flow_model
,
shape_slat
:
SparseTensor
,
sampler_params
:
dict
=
{},
)
->
SparseTensor
:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
shape_slat (SparseTensor): The structured latent for shape
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
])[
None
].
to
(
shape_slat
.
device
)
mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
])[
None
].
to
(
shape_slat
.
device
)
shape_slat
=
(
shape_slat
-
mean
)
/
std
in_channels
=
flow_model
.
in_channels
if
isinstance
(
flow_model
,
nn
.
Module
)
else
flow_model
[
0
].
in_channels
noise
=
shape_slat
.
replace
(
feats
=
torch
.
randn
(
shape_slat
.
coords
.
shape
[
0
],
in_channels
-
shape_slat
.
feats
.
shape
[
1
]).
to
(
self
.
device
))
sampler_params
=
{
**
self
.
tex_slat_sampler_params
,
**
sampler_params
}
if
self
.
low_vram
:
flow_model
.
to
(
self
.
device
)
slat
=
self
.
tex_slat_sampler
.
sample
(
flow_model
,
noise
,
concat_cond
=
shape_slat
,
**
cond
,
**
sampler_params
,
verbose
=
True
,
tqdm_desc
=
"Sampling texture SLat"
,
).
samples
if
self
.
low_vram
:
flow_model
.
cpu
()
std
=
torch
.
tensor
(
self
.
tex_slat_normalization
[
'std'
])[
None
].
to
(
slat
.
device
)
mean
=
torch
.
tensor
(
self
.
tex_slat_normalization
[
'mean'
])[
None
].
to
(
slat
.
device
)
slat
=
slat
*
std
+
mean
return
slat
def
decode_tex_slat
(
self
,
slat
:
SparseTensor
,
)
->
SparseTensor
:
"""
Decode the structured latent.
Args:
slat (SparseTensor): The structured latent.
Returns:
SparseTensor: The decoded texture voxels
"""
if
self
.
low_vram
:
self
.
models
[
'tex_slat_decoder'
].
to
(
self
.
device
)
ret
=
self
.
models
[
'tex_slat_decoder'
](
slat
)
*
0.5
+
0.5
if
self
.
low_vram
:
self
.
models
[
'tex_slat_decoder'
].
cpu
()
return
ret
def
postprocess_mesh
(
self
,
mesh
:
trimesh
.
Trimesh
,
pbr_voxel
:
SparseTensor
,
resolution
:
int
=
1024
,
texture_size
:
int
=
1024
,
)
->
trimesh
.
Trimesh
:
vertices
=
mesh
.
vertices
faces
=
mesh
.
faces
normals
=
mesh
.
vertex_normals
vertices_torch
=
torch
.
from_numpy
(
vertices
).
float
().
cuda
()
faces_torch
=
torch
.
from_numpy
(
faces
).
int
().
cuda
()
if
hasattr
(
mesh
,
'visual'
)
and
hasattr
(
mesh
.
visual
,
'uv'
)
and
mesh
.
visual
.
uv
is
not
None
:
uvs
=
mesh
.
visual
.
uv
.
copy
()
uvs
[:,
1
]
=
1
-
uvs
[:,
1
]
uvs_torch
=
torch
.
from_numpy
(
uvs
).
float
().
cuda
()
else
:
_cumesh
=
cumesh
.
CuMesh
()
_cumesh
.
init
(
vertices_torch
,
faces_torch
)
vertices_torch
,
faces_torch
,
uvs_torch
,
vmap
=
_cumesh
.
uv_unwrap
(
return_vmaps
=
True
)
vertices_torch
=
vertices_torch
.
cuda
()
faces_torch
=
faces_torch
.
cuda
()
uvs_torch
=
uvs_torch
.
cuda
()
vertices
=
vertices_torch
.
cpu
().
numpy
()
faces
=
faces_torch
.
cpu
().
numpy
()
uvs
=
uvs_torch
.
cpu
().
numpy
()
normals
=
normals
[
vmap
.
cpu
().
numpy
()]
# rasterize
ctx
=
dr
.
RasterizeCudaContext
()
uvs_torch
=
torch
.
cat
([
uvs_torch
*
2
-
1
,
torch
.
zeros_like
(
uvs_torch
[:,
:
1
]),
torch
.
ones_like
(
uvs_torch
[:,
:
1
])],
dim
=-
1
).
unsqueeze
(
0
)
rast
,
_
=
dr
.
rasterize
(
ctx
,
uvs_torch
,
faces_torch
,
resolution
=
[
texture_size
,
texture_size
],
)
mask
=
rast
[
0
,
...,
3
]
>
0
pos
=
dr
.
interpolate
(
vertices_torch
.
unsqueeze
(
0
),
rast
,
faces_torch
)[
0
][
0
]
attrs
=
torch
.
zeros
(
texture_size
,
texture_size
,
pbr_voxel
.
shape
[
1
],
device
=
self
.
device
)
if
mask
.
any
():
attrs
[
mask
]
=
flex_gemm
.
ops
.
grid_sample
.
grid_sample_3d
(
pbr_voxel
.
feats
,
pbr_voxel
.
coords
,
shape
=
torch
.
Size
([
*
pbr_voxel
.
shape
,
*
pbr_voxel
.
spatial_shape
]),
grid
=
((
pos
[
mask
]
+
0.5
)
*
resolution
).
reshape
(
1
,
-
1
,
3
),
mode
=
'trilinear'
,
)
# construct mesh
mask
=
mask
.
cpu
().
numpy
()
base_color
=
np
.
clip
(
attrs
[...,
self
.
pbr_attr_layout
[
'base_color'
]].
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
metallic
=
np
.
clip
(
attrs
[...,
self
.
pbr_attr_layout
[
'metallic'
]].
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
roughness
=
np
.
clip
(
attrs
[...,
self
.
pbr_attr_layout
[
'roughness'
]].
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
alpha
=
np
.
clip
(
attrs
[...,
self
.
pbr_attr_layout
[
'alpha'
]].
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
# extend
mask
=
(
~
mask
).
astype
(
np
.
uint8
)
base_color
=
cv2
.
inpaint
(
base_color
,
mask
,
3
,
cv2
.
INPAINT_TELEA
)
metallic
=
cv2
.
inpaint
(
metallic
,
mask
,
1
,
cv2
.
INPAINT_TELEA
)[...,
None
]
roughness
=
cv2
.
inpaint
(
roughness
,
mask
,
1
,
cv2
.
INPAINT_TELEA
)[...,
None
]
alpha
=
cv2
.
inpaint
(
alpha
,
mask
,
1
,
cv2
.
INPAINT_TELEA
)[...,
None
]
material
=
trimesh
.
visual
.
material
.
PBRMaterial
(
baseColorTexture
=
Image
.
fromarray
(
np
.
concatenate
([
base_color
,
alpha
],
axis
=-
1
)),
baseColorFactor
=
np
.
array
([
255
,
255
,
255
,
255
],
dtype
=
np
.
uint8
),
metallicRoughnessTexture
=
Image
.
fromarray
(
np
.
concatenate
([
np
.
zeros_like
(
metallic
),
roughness
,
metallic
],
axis
=-
1
)),
metallicFactor
=
1.0
,
roughnessFactor
=
1.0
,
alphaMode
=
'OPAQUE'
,
doubleSided
=
True
,
)
# Swap Y and Z axes, invert Y (common conversion for GLB compatibility)
vertices
[:,
1
],
vertices
[:,
2
]
=
vertices
[:,
2
],
-
vertices
[:,
1
]
normals
[:,
1
],
normals
[:,
2
]
=
normals
[:,
2
],
-
normals
[:,
1
]
uvs
[:,
1
]
=
1
-
uvs
[:,
1
]
# Flip UV V-coordinate
textured_mesh
=
trimesh
.
Trimesh
(
vertices
=
vertices
,
faces
=
faces
,
vertex_normals
=
normals
,
process
=
False
,
visual
=
trimesh
.
visual
.
TextureVisuals
(
uv
=
uvs
,
material
=
material
)
)
return
textured_mesh
@
torch
.
no_grad
()
def
run
(
self
,
mesh
:
trimesh
.
Trimesh
,
image
:
Image
.
Image
,
seed
:
int
=
42
,
tex_slat_sampler_params
:
dict
=
{},
preprocess_image
:
bool
=
True
,
resolution
:
int
=
1024
,
texture_size
:
int
=
2048
,
)
->
trimesh
.
Trimesh
:
"""
Run the pipeline.
Args:
mesh (trimesh.Trimesh): The mesh to texture.
image (Image.Image): The image prompt.
seed (int): The random seed.
tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
if
preprocess_image
:
image
=
self
.
preprocess_image
(
image
)
mesh
=
self
.
preprocess_mesh
(
mesh
)
torch
.
manual_seed
(
seed
)
cond
=
self
.
get_cond
([
image
],
512
)
if
resolution
==
512
else
self
.
get_cond
([
image
],
1024
)
shape_slat
=
self
.
encode_shape_slat
(
mesh
,
resolution
)
tex_model
=
self
.
models
[
'tex_slat_flow_model_512'
]
if
resolution
==
512
else
self
.
models
[
'tex_slat_flow_model_1024'
]
tex_slat
=
self
.
sample_tex_slat
(
cond
,
tex_model
,
shape_slat
,
tex_slat_sampler_params
)
pbr_voxel
=
self
.
decode_tex_slat
(
tex_slat
)
out_mesh
=
self
.
postprocess_mesh
(
mesh
,
pbr_voxel
,
resolution
,
texture_size
)
return
out_mesh
TRELLIS.2_DCU/trellis2/renderers/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
'MeshRenderer'
:
'mesh_renderer'
,
'VoxelRenderer'
:
'voxel_renderer'
,
'PbrMeshRenderer'
:
'pbr_mesh_renderer'
,
'EnvMap'
:
'pbr_mesh_renderer'
,
}
__submodules
=
[]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
# For Pylance
if
__name__
==
'__main__'
:
from
.mesh_renderer
import
MeshRenderer
from
.voxel_renderer
import
VoxelRenderer
from
.pbr_mesh_renderer
import
PbrMeshRenderer
,
EnvMap
\ No newline at end of file
TRELLIS.2_DCU/trellis2/renderers/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/renderers/__pycache__/mesh_renderer.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/renderers/__pycache__/pbr_mesh_renderer.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/renderers/__pycache__/voxel_renderer.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/renderers/mesh_renderer.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
from
easydict
import
EasyDict
as
edict
from
..representations.mesh
import
Mesh
,
MeshWithVoxel
,
MeshWithPbrMaterial
,
TextureFilterMode
,
AlphaMode
,
TextureWrapMode
import
torch.nn.functional
as
F
def
intrinsics_to_projection
(
intrinsics
:
torch
.
Tensor
,
near
:
float
,
far
:
float
,
)
->
torch
.
Tensor
:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx
,
fy
=
intrinsics
[
0
,
0
],
intrinsics
[
1
,
1
]
cx
,
cy
=
intrinsics
[
0
,
2
],
intrinsics
[
1
,
2
]
ret
=
torch
.
zeros
((
4
,
4
),
dtype
=
intrinsics
.
dtype
,
device
=
intrinsics
.
device
)
ret
[
0
,
0
]
=
2
*
fx
ret
[
1
,
1
]
=
2
*
fy
ret
[
0
,
2
]
=
2
*
cx
-
1
ret
[
1
,
2
]
=
-
2
*
cy
+
1
ret
[
2
,
2
]
=
(
far
+
near
)
/
(
far
-
near
)
ret
[
2
,
3
]
=
2
*
near
*
far
/
(
near
-
far
)
ret
[
3
,
2
]
=
1.
return
ret
class
MeshRenderer
:
"""
Renderer for the Mesh representation.
Args:
rendering_options (dict): Rendering options.
"""
def
__init__
(
self
,
rendering_options
=
{},
device
=
'cuda'
):
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
self
.
rendering_options
=
edict
({
"resolution"
:
None
,
"near"
:
None
,
"far"
:
None
,
"ssaa"
:
1
,
"chunk_size"
:
None
,
"antialias"
:
True
,
"clamp_barycentric_coords"
:
False
,
})
self
.
rendering_options
.
update
(
rendering_options
)
self
.
glctx
=
dr
.
RasterizeCudaContext
(
device
=
device
)
self
.
device
=
device
def
render
(
self
,
mesh
:
Mesh
,
extrinsics
:
torch
.
Tensor
,
intrinsics
:
torch
.
Tensor
,
return_types
=
[
"mask"
,
"normal"
,
"depth"
],
transformation
:
Optional
[
torch
.
Tensor
]
=
None
)
->
edict
:
"""
Render the mesh.
Args:
mesh : meshmodel
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal"
Returns:
edict based on return_types containing:
attr (torch.Tensor): [C, H, W] rendered attr image
depth (torch.Tensor): [H, W] rendered depth image
normal (torch.Tensor): [3, H, W] rendered normal image
mask (torch.Tensor): [H, W] rendered mask image
"""
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
resolution
=
self
.
rendering_options
[
"resolution"
]
near
=
self
.
rendering_options
[
"near"
]
far
=
self
.
rendering_options
[
"far"
]
ssaa
=
self
.
rendering_options
[
"ssaa"
]
chunk_size
=
self
.
rendering_options
[
"chunk_size"
]
antialias
=
self
.
rendering_options
[
"antialias"
]
clamp_barycentric_coords
=
self
.
rendering_options
[
"clamp_barycentric_coords"
]
if
mesh
.
vertices
.
shape
[
0
]
==
0
or
mesh
.
faces
.
shape
[
0
]
==
0
:
ret_dict
=
edict
()
for
type
in
return_types
:
if
type
==
"mask"
:
ret_dict
[
type
]
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
elif
type
==
"depth"
:
ret_dict
[
type
]
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
elif
type
==
"normal"
:
ret_dict
[
type
]
=
torch
.
full
((
3
,
resolution
,
resolution
),
0.5
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
elif
type
==
"coord"
:
ret_dict
[
type
]
=
torch
.
zeros
((
3
,
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
elif
type
==
"attr"
:
if
isinstance
(
mesh
,
MeshWithVoxel
):
ret_dict
[
type
]
=
torch
.
zeros
((
mesh
.
attrs
.
shape
[
-
1
],
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
else
:
ret_dict
[
type
]
=
torch
.
zeros
((
mesh
.
vertex_attrs
.
shape
[
-
1
],
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
return
ret_dict
perspective
=
intrinsics_to_projection
(
intrinsics
,
near
,
far
)
full_proj
=
(
perspective
@
extrinsics
).
unsqueeze
(
0
)
extrinsics
=
extrinsics
.
unsqueeze
(
0
)
vertices
=
mesh
.
vertices
.
unsqueeze
(
0
)
vertices_homo
=
torch
.
cat
([
vertices
,
torch
.
ones_like
(
vertices
[...,
:
1
])],
dim
=-
1
)
if
transformation
is
not
None
:
vertices_homo
=
torch
.
bmm
(
vertices_homo
,
transformation
.
unsqueeze
(
0
).
transpose
(
-
1
,
-
2
))
vertices
=
vertices_homo
[...,
:
3
].
contiguous
()
vertices_camera
=
torch
.
bmm
(
vertices_homo
,
extrinsics
.
transpose
(
-
1
,
-
2
))
vertices_clip
=
torch
.
bmm
(
vertices_homo
,
full_proj
.
transpose
(
-
1
,
-
2
))
faces
=
mesh
.
faces
if
'normal'
in
return_types
:
v0
=
vertices_camera
[
0
,
mesh
.
faces
[:,
0
],
:
3
]
v1
=
vertices_camera
[
0
,
mesh
.
faces
[:,
1
],
:
3
]
v2
=
vertices_camera
[
0
,
mesh
.
faces
[:,
2
],
:
3
]
e0
=
v1
-
v0
e1
=
v2
-
v0
face_normal
=
torch
.
cross
(
e0
,
e1
,
dim
=
1
)
face_normal
=
F
.
normalize
(
face_normal
,
dim
=
1
)
face_normal
=
torch
.
where
(
torch
.
sum
(
face_normal
*
v0
,
dim
=
1
,
keepdim
=
True
)
>
0
,
face_normal
,
-
face_normal
)
out_dict
=
edict
()
if
chunk_size
is
None
:
rast
,
rast_db
=
dr
.
rasterize
(
self
.
glctx
,
vertices_clip
,
faces
,
(
resolution
*
ssaa
,
resolution
*
ssaa
)
)
if
clamp_barycentric_coords
:
rast
[...,
:
2
]
=
torch
.
clamp
(
rast
[...,
:
2
],
0
,
1
)
rast
[...,
:
2
]
/=
torch
.
where
(
rast
[...,
:
2
].
sum
(
dim
=-
1
,
keepdim
=
True
)
>
1
,
rast
[...,
:
2
].
sum
(
dim
=-
1
,
keepdim
=
True
),
torch
.
ones_like
(
rast
[...,
:
2
]))
for
type
in
return_types
:
img
=
None
if
type
==
"mask"
:
img
=
(
rast
[...,
-
1
:]
>
0
).
float
()
if
antialias
:
img
=
dr
.
antialias
(
img
,
rast
,
vertices_clip
,
faces
)
elif
type
==
"depth"
:
img
=
dr
.
interpolate
(
vertices_camera
[...,
2
:
3
].
contiguous
(),
rast
,
faces
)[
0
]
if
antialias
:
img
=
dr
.
antialias
(
img
,
rast
,
vertices_clip
,
faces
)
elif
type
==
"normal"
:
img
=
dr
.
interpolate
(
face_normal
.
unsqueeze
(
0
),
rast
,
torch
.
arange
(
face_normal
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
self
.
device
).
unsqueeze
(
1
).
repeat
(
1
,
3
).
contiguous
())[
0
]
if
antialias
:
img
=
dr
.
antialias
(
img
,
rast
,
vertices_clip
,
faces
)
img
=
(
img
+
1
)
/
2
elif
type
==
"coord"
:
img
=
dr
.
interpolate
(
vertices
,
rast
,
faces
)[
0
]
if
antialias
:
img
=
dr
.
antialias
(
img
,
rast
,
vertices_clip
,
faces
)
elif
type
==
"attr"
:
if
isinstance
(
mesh
,
MeshWithVoxel
):
if
'grid_sample_3d'
not
in
globals
():
from
flex_gemm.ops.grid_sample
import
grid_sample_3d
mask
=
rast
[...,
-
1
:]
>
0
xyz
=
dr
.
interpolate
(
vertices
,
rast
,
faces
)[
0
]
xyz
=
((
xyz
-
mesh
.
origin
)
/
mesh
.
voxel_size
).
reshape
(
1
,
-
1
,
3
)
img
=
grid_sample_3d
(
mesh
.
attrs
,
torch
.
cat
([
torch
.
zeros_like
(
mesh
.
coords
[...,
:
1
]),
mesh
.
coords
],
dim
=-
1
),
mesh
.
voxel_shape
,
xyz
,
mode
=
'trilinear'
)
img
=
img
.
reshape
(
1
,
resolution
*
ssaa
,
resolution
*
ssaa
,
mesh
.
attrs
.
shape
[
-
1
])
*
mask
elif
isinstance
(
mesh
,
MeshWithPbrMaterial
):
tri_id
=
rast
[
0
,
:,
:,
-
1
:]
mask
=
tri_id
>
0
uv_coords
=
mesh
.
uv_coords
.
reshape
(
1
,
-
1
,
2
)
texc
,
texd
=
dr
.
interpolate
(
uv_coords
,
rast
,
torch
.
arange
(
mesh
.
uv_coords
.
shape
[
0
]
*
3
,
dtype
=
torch
.
int
,
device
=
self
.
device
).
reshape
(
-
1
,
3
),
rast_db
=
rast_db
,
diff_attrs
=
'all'
)
# Fix problematic texture coordinates
texc
=
torch
.
nan_to_num
(
texc
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texc
=
torch
.
clamp
(
texc
,
min
=-
1e3
,
max
=
1e3
)
texd
=
torch
.
nan_to_num
(
texd
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texd
=
torch
.
clamp
(
texd
,
min
=-
1e3
,
max
=
1e3
)
mid
=
mesh
.
material_ids
[(
tri_id
-
1
).
long
()]
imgs
=
{
'base_color'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
3
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'metallic'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'roughness'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'alpha'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
}
for
id
,
mat
in
enumerate
(
mesh
.
materials
):
mat_mask
=
(
mid
==
id
).
float
()
*
mask
.
float
()
mat_texc
=
texc
*
mat_mask
mat_texd
=
texd
*
mat_mask
if
mat
.
base_color_texture
is
not
None
:
base_color
=
dr
.
texture
(
mat
.
base_color_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
base_color_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
base_color_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'base_color'
]
+=
base_color
*
mat
.
base_color_factor
*
mat_mask
else
:
imgs
[
'base_color'
]
+=
mat
.
base_color_factor
*
mat_mask
if
mat
.
metallic_texture
is
not
None
:
metallic
=
dr
.
texture
(
mat
.
metallic_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
metallic_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
metallic_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'metallic'
]
+=
metallic
*
mat
.
metallic_factor
*
mat_mask
else
:
imgs
[
'metallic'
]
+=
mat
.
metallic_factor
*
mat_mask
if
mat
.
roughness_texture
is
not
None
:
roughness
=
dr
.
texture
(
mat
.
roughness_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
roughness_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
roughness_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'roughness'
]
+=
roughness
*
mat
.
roughness_factor
*
mat_mask
else
:
imgs
[
'roughness'
]
+=
mat
.
roughness_factor
*
mat_mask
if
mat
.
alpha_mode
==
AlphaMode
.
OPAQUE
:
imgs
[
'alpha'
]
+=
1.0
*
mat_mask
else
:
if
mat
.
alpha_texture
is
not
None
:
alpha
=
dr
.
texture
(
mat
.
alpha_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
alpha_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
alpha_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
imgs
[
'alpha'
]
+=
(
alpha
*
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
imgs
[
'alpha'
]
+=
alpha
*
mat
.
alpha_factor
*
mat_mask
else
:
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
imgs
[
'alpha'
]
+=
(
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
imgs
[
'alpha'
]
+=
mat
.
alpha_factor
*
mat_mask
img
=
torch
.
cat
([
imgs
[
name
]
for
name
in
imgs
.
keys
()],
dim
=-
1
).
unsqueeze
(
0
)
else
:
img
=
dr
.
interpolate
(
mesh
.
vertex_attrs
.
unsqueeze
(
0
),
rast
,
faces
)[
0
]
if
antialias
:
img
=
dr
.
antialias
(
img
,
rast
,
vertices_clip
,
faces
)
out_dict
[
type
]
=
img
else
:
z_buffer
=
torch
.
full
((
1
,
resolution
*
ssaa
,
resolution
*
ssaa
),
torch
.
inf
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
for
i
in
range
(
0
,
faces
.
shape
[
0
],
chunk_size
):
faces_chunk
=
faces
[
i
:
i
+
chunk_size
]
rast
,
rast_db
=
dr
.
rasterize
(
self
.
glctx
,
vertices_clip
,
faces_chunk
,
(
resolution
*
ssaa
,
resolution
*
ssaa
)
)
z_filter
=
torch
.
logical_and
(
rast
[...,
3
]
!=
0
,
rast
[...,
2
]
<
z_buffer
)
z_buffer
[
z_filter
]
=
rast
[
z_filter
][...,
2
]
for
type
in
return_types
:
img
=
None
if
type
==
"mask"
:
img
=
(
rast
[...,
-
1
:]
>
0
).
float
()
elif
type
==
"depth"
:
img
=
dr
.
interpolate
(
vertices_camera
[...,
2
:
3
].
contiguous
(),
rast
,
faces_chunk
)[
0
]
elif
type
==
"normal"
:
face_normal_chunk
=
face_normal
[
i
:
i
+
chunk_size
]
img
=
dr
.
interpolate
(
face_normal_chunk
.
unsqueeze
(
0
),
rast
,
torch
.
arange
(
face_normal_chunk
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
self
.
device
).
unsqueeze
(
1
).
repeat
(
1
,
3
).
contiguous
())[
0
]
img
=
(
img
+
1
)
/
2
elif
type
==
"coord"
:
img
=
dr
.
interpolate
(
vertices
,
rast
,
faces_chunk
)[
0
]
elif
type
==
"attr"
:
if
isinstance
(
mesh
,
MeshWithVoxel
):
if
'grid_sample_3d'
not
in
globals
():
from
flex_gemm.ops.grid_sample
import
grid_sample_3d
mask
=
rast
[...,
-
1
:]
>
0
xyz
=
dr
.
interpolate
(
vertices
,
rast
,
faces_chunk
)[
0
]
xyz
=
((
xyz
-
mesh
.
origin
)
/
mesh
.
voxel_size
).
reshape
(
1
,
-
1
,
3
)
img
=
grid_sample_3d
(
mesh
.
attrs
,
torch
.
cat
([
torch
.
zeros_like
(
mesh
.
coords
[...,
:
1
]),
mesh
.
coords
],
dim
=-
1
),
mesh
.
voxel_shape
,
xyz
,
mode
=
'trilinear'
)
img
=
img
.
reshape
(
1
,
resolution
*
ssaa
,
resolution
*
ssaa
,
mesh
.
attrs
.
shape
[
-
1
])
*
mask
elif
isinstance
(
mesh
,
MeshWithPbrMaterial
):
tri_id
=
rast
[
0
,
:,
:,
-
1
:]
mask
=
tri_id
>
0
uv_coords
=
mesh
.
uv_coords
.
reshape
(
1
,
-
1
,
2
)
texc
,
texd
=
dr
.
interpolate
(
uv_coords
,
rast
,
torch
.
arange
(
mesh
.
uv_coords
.
shape
[
0
]
*
3
,
dtype
=
torch
.
int
,
device
=
self
.
device
).
reshape
(
-
1
,
3
),
rast_db
=
rast_db
,
diff_attrs
=
'all'
)
# Fix problematic texture coordinates
texc
=
torch
.
nan_to_num
(
texc
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texc
=
torch
.
clamp
(
texc
,
min
=-
1e3
,
max
=
1e3
)
texd
=
torch
.
nan_to_num
(
texd
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texd
=
torch
.
clamp
(
texd
,
min
=-
1e3
,
max
=
1e3
)
mid
=
mesh
.
material_ids
[(
tri_id
-
1
).
long
()]
imgs
=
{
'base_color'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
3
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'metallic'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'roughness'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
'alpha'
:
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
}
for
id
,
mat
in
enumerate
(
mesh
.
materials
):
mat_mask
=
(
mid
==
id
).
float
()
*
mask
.
float
()
mat_texc
=
texc
*
mat_mask
mat_texd
=
texd
*
mat_mask
if
mat
.
base_color_texture
is
not
None
:
base_color
=
dr
.
texture
(
mat
.
base_color_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
base_color_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
base_color_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'base_color'
]
+=
base_color
*
mat
.
base_color_factor
*
mat_mask
else
:
imgs
[
'base_color'
]
+=
mat
.
base_color_factor
*
mat_mask
if
mat
.
metallic_texture
is
not
None
:
metallic
=
dr
.
texture
(
mat
.
metallic_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
metallic_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
metallic_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'metallic'
]
+=
metallic
*
mat
.
metallic_factor
*
mat_mask
else
:
imgs
[
'metallic'
]
+=
mat
.
metallic_factor
*
mat_mask
if
mat
.
roughness_texture
is
not
None
:
roughness
=
dr
.
texture
(
mat
.
roughness_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
roughness_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
roughness_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
imgs
[
'roughness'
]
+=
roughness
*
mat
.
roughness_factor
*
mat_mask
else
:
imgs
[
'roughness'
]
+=
mat
.
roughness_factor
*
mat_mask
if
mat
.
alpha_mode
==
AlphaMode
.
OPAQUE
:
imgs
[
'alpha'
]
+=
1.0
*
mat_mask
else
:
if
mat
.
alpha_texture
is
not
None
:
alpha
=
dr
.
texture
(
mat
.
alpha_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
alpha_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
alpha_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
imgs
[
'alpha'
]
+=
(
alpha
*
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
imgs
[
'alpha'
]
+=
alpha
*
mat
.
alpha_factor
*
mat_mask
else
:
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
imgs
[
'alpha'
]
+=
(
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
imgs
[
'alpha'
]
+=
mat
.
alpha_factor
*
mat_mask
img
=
torch
.
cat
([
imgs
[
name
]
for
name
in
imgs
.
keys
()],
dim
=-
1
).
unsqueeze
(
0
)
else
:
img
=
dr
.
interpolate
(
mesh
.
vertex_attrs
.
unsqueeze
(
0
),
rast
,
faces_chunk
)[
0
]
if
type
not
in
out_dict
:
out_dict
[
type
]
=
img
else
:
out_dict
[
type
][
z_filter
]
=
img
[
z_filter
]
for
type
in
return_types
:
img
=
out_dict
[
type
]
if
ssaa
>
1
:
img
=
F
.
interpolate
(
img
.
permute
(
0
,
3
,
1
,
2
),
(
resolution
,
resolution
),
mode
=
'bilinear'
,
align_corners
=
False
,
antialias
=
True
)
img
=
img
.
squeeze
()
else
:
img
=
img
.
permute
(
0
,
3
,
1
,
2
).
squeeze
()
out_dict
[
type
]
=
img
if
isinstance
(
mesh
,
(
MeshWithVoxel
,
MeshWithPbrMaterial
))
and
'attr'
in
return_types
:
for
k
,
s
in
mesh
.
layout
.
items
():
out_dict
[
k
]
=
out_dict
[
'attr'
][
s
]
del
out_dict
[
'attr'
]
return
out_dict
TRELLIS.2_DCU/trellis2/renderers/pbr_mesh_renderer.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
from
easydict
import
EasyDict
as
edict
import
numpy
as
np
import
utils3d
from
..representations.mesh
import
Mesh
,
MeshWithVoxel
,
MeshWithPbrMaterial
,
TextureFilterMode
,
AlphaMode
,
TextureWrapMode
import
torch.nn.functional
as
F
from
..utils.pipeline_logger
import
get_logger
,
log_mesh
,
log_uv
,
log_tensor
,
elapsed
,
section
from
..modules.sparse.linear
import
ROCM_SAFE_CHUNK
def
_safe_transform4x4
(
vertices_homo
:
torch
.
Tensor
,
matrix
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Chunked drop-in for torch.bmm(vertices_homo, matrix) to work around the
ROCm GEMM bug where N > ~800k produces corrupt results.
vertices_homo: [B, N, 4] matrix: [B, 4, 4]
"""
B
,
N
,
_
=
vertices_homo
.
shape
if
N
<=
ROCM_SAFE_CHUNK
:
return
torch
.
bmm
(
vertices_homo
,
matrix
)
parts
=
[]
for
s
in
range
(
0
,
N
,
ROCM_SAFE_CHUNK
):
e
=
min
(
s
+
ROCM_SAFE_CHUNK
,
N
)
parts
.
append
(
torch
.
bmm
(
vertices_homo
[:,
s
:
e
,
:],
matrix
))
return
torch
.
cat
(
parts
,
dim
=
1
)
def
cube_to_dir
(
s
,
x
,
y
):
if
s
==
0
:
rx
,
ry
,
rz
=
torch
.
ones_like
(
x
),
-
x
,
-
y
elif
s
==
1
:
rx
,
ry
,
rz
=
-
torch
.
ones_like
(
x
),
x
,
-
y
elif
s
==
2
:
rx
,
ry
,
rz
=
x
,
y
,
torch
.
ones_like
(
x
)
elif
s
==
3
:
rx
,
ry
,
rz
=
x
,
-
y
,
-
torch
.
ones_like
(
x
)
elif
s
==
4
:
rx
,
ry
,
rz
=
x
,
torch
.
ones_like
(
x
),
-
y
elif
s
==
5
:
rx
,
ry
,
rz
=
-
x
,
-
torch
.
ones_like
(
x
),
-
y
return
torch
.
stack
((
rx
,
ry
,
rz
),
dim
=-
1
)
def
latlong_to_cubemap
(
latlong_map
,
res
):
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
cubemap
=
torch
.
zeros
(
6
,
res
[
0
],
res
[
1
],
latlong_map
.
shape
[
-
1
],
dtype
=
torch
.
float32
,
device
=
'cuda'
)
for
s
in
range
(
6
):
gy
,
gx
=
torch
.
meshgrid
(
torch
.
linspace
(
-
1.0
+
1.0
/
res
[
0
],
1.0
-
1.0
/
res
[
0
],
res
[
0
],
device
=
'cuda'
),
torch
.
linspace
(
-
1.0
+
1.0
/
res
[
1
],
1.0
-
1.0
/
res
[
1
],
res
[
1
],
device
=
'cuda'
),
indexing
=
'ij'
)
v
=
F
.
normalize
(
cube_to_dir
(
s
,
gx
,
gy
),
dim
=-
1
)
tu
=
torch
.
atan2
(
v
[...,
0
:
1
],
-
v
[...,
2
:
3
])
/
(
2
*
np
.
pi
)
+
0.5
tv
=
torch
.
acos
(
torch
.
clamp
(
v
[...,
1
:
2
],
min
=-
1
,
max
=
1
))
/
np
.
pi
texcoord
=
torch
.
cat
((
tu
,
tv
),
dim
=-
1
)
cubemap
[
s
,
...]
=
dr
.
texture
(
latlong_map
[
None
,
...],
texcoord
[
None
,
...],
filter_mode
=
'linear'
)[
0
]
return
cubemap
class
EnvMap
:
def
__init__
(
self
,
image
:
torch
.
Tensor
):
self
.
image
=
image
@
property
def
_backend
(
self
):
if
not
hasattr
(
self
,
'_nvdiffrec_envlight'
):
if
'EnvironmentLight'
not
in
globals
():
from
nvdiffrec_render.light
import
EnvironmentLight
cubemap
=
latlong_to_cubemap
(
self
.
image
,
[
512
,
512
])
self
.
_nvdiffrec_envlight
=
EnvironmentLight
(
cubemap
)
self
.
_nvdiffrec_envlight
.
build_mips
()
return
self
.
_nvdiffrec_envlight
def
shade
(
self
,
gb_pos
,
gb_normal
,
kd
,
ks
,
view_pos
,
specular
=
True
):
return
self
.
_backend
.
shade
(
gb_pos
,
gb_normal
,
kd
,
ks
,
view_pos
,
specular
)
def
sample
(
self
,
directions
:
torch
.
Tensor
):
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
return
dr
.
texture
(
self
.
_backend
.
base
.
unsqueeze
(
0
),
directions
.
unsqueeze
(
0
),
boundary_mode
=
'cube'
,
)[
0
]
def
intrinsics_to_projection
(
intrinsics
:
torch
.
Tensor
,
near
:
float
,
far
:
float
,
)
->
torch
.
Tensor
:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx
,
fy
=
intrinsics
[
0
,
0
],
intrinsics
[
1
,
1
]
cx
,
cy
=
intrinsics
[
0
,
2
],
intrinsics
[
1
,
2
]
ret
=
torch
.
zeros
((
4
,
4
),
dtype
=
intrinsics
.
dtype
,
device
=
intrinsics
.
device
)
ret
[
0
,
0
]
=
2
*
fx
ret
[
1
,
1
]
=
2
*
fy
ret
[
0
,
2
]
=
2
*
cx
-
1
ret
[
1
,
2
]
=
-
2
*
cy
+
1
ret
[
2
,
2
]
=
(
far
+
near
)
/
(
far
-
near
)
ret
[
2
,
3
]
=
2
*
near
*
far
/
(
near
-
far
)
ret
[
3
,
2
]
=
1.
return
ret
def
screen_space_ambient_occlusion
(
depth
:
torch
.
Tensor
,
normal
:
torch
.
Tensor
,
perspective
:
torch
.
Tensor
,
radius
:
float
=
0.1
,
bias
:
float
=
1e-6
,
samples
:
int
=
64
,
intensity
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""
Screen space ambient occlusion (SSAO)
Args:
depth (torch.Tensor): [H, W, 1] depth image
normal (torch.Tensor): [H, W, 3] normal image
perspective (torch.Tensor): [4, 4] camera projection matrix
radius (float): radius of the SSAO kernel
bias (float): bias to avoid self-occlusion
samples (int): number of samples to use for the SSAO kernel
intensity (float): intensity of the SSAO effect
Returns:
(torch.Tensor): [H, W, 1] SSAO image
"""
device
=
depth
.
device
H
,
W
,
_
=
depth
.
shape
fx
=
perspective
[
0
,
0
]
fy
=
perspective
[
1
,
1
]
cx
=
perspective
[
0
,
2
]
cy
=
perspective
[
1
,
2
]
y_grid
,
x_grid
=
torch
.
meshgrid
(
(
torch
.
arange
(
H
,
device
=
device
)
+
0.5
)
/
H
*
2
-
1
,
(
torch
.
arange
(
W
,
device
=
device
)
+
0.5
)
/
W
*
2
-
1
,
indexing
=
'ij'
)
x_view
=
(
x_grid
.
float
()
-
cx
)
*
depth
[...,
0
]
/
fx
y_view
=
(
y_grid
.
float
()
-
cy
)
*
depth
[...,
0
]
/
fy
view_pos
=
torch
.
stack
([
x_view
,
y_view
,
depth
[...,
0
]],
dim
=-
1
)
# [H, W, 3]
depth_feat
=
depth
.
permute
(
2
,
0
,
1
).
unsqueeze
(
0
)
occlusion
=
torch
.
zeros
((
H
,
W
),
device
=
device
)
# start sampling
for
_
in
range
(
samples
):
# sample normal distribution, if inside, flip the sign
rnd_vec
=
torch
.
randn
(
H
,
W
,
3
,
device
=
device
)
rnd_vec
=
F
.
normalize
(
rnd_vec
,
p
=
2
,
dim
=-
1
)
dot_val
=
torch
.
sum
(
rnd_vec
*
normal
,
dim
=-
1
,
keepdim
=
True
)
sample_dir
=
torch
.
sign
(
dot_val
)
*
rnd_vec
scale
=
torch
.
rand
(
H
,
W
,
1
,
device
=
device
)
scale
=
scale
*
scale
sample_pos
=
view_pos
+
sample_dir
*
radius
*
scale
sample_z
=
sample_pos
[...,
2
]
# project to screen space
z_safe
=
torch
.
clamp
(
sample_pos
[...,
2
],
min
=
1e-5
)
proj_u
=
(
sample_pos
[...,
0
]
*
fx
/
z_safe
)
+
cx
proj_v
=
(
sample_pos
[...,
1
]
*
fy
/
z_safe
)
+
cy
grid
=
torch
.
stack
([
proj_u
,
proj_v
],
dim
=-
1
).
unsqueeze
(
0
)
geo_z
=
F
.
grid_sample
(
depth_feat
,
grid
,
mode
=
'nearest'
,
padding_mode
=
'border'
).
squeeze
()
range_check
=
torch
.
abs
(
geo_z
-
sample_z
)
<
radius
is_occluded
=
(
geo_z
<=
sample_z
-
bias
)
&
range_check
occlusion
+=
is_occluded
.
float
()
f_occ
=
occlusion
/
samples
*
intensity
f_occ
=
torch
.
clamp
(
f_occ
,
0.0
,
1.0
)
return
f_occ
.
unsqueeze
(
-
1
)
def
aces_tonemapping
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies ACES tone mapping curve to an HDR image tensor.
Input: x - HDR tensor, shape (..., 3), range [0, +inf)
Output: LDR tensor, same shape, range [0, 1]
"""
a
=
2.51
b
=
0.03
c
=
2.43
d
=
0.59
e
=
0.14
# Apply the ACES fitted curve
mapped
=
(
x
*
(
a
*
x
+
b
))
/
(
x
*
(
c
*
x
+
d
)
+
e
)
# Clamp to [0, 1] for display or saving
return
torch
.
clamp
(
mapped
,
0.0
,
1.0
)
def
gamma_correction
(
x
:
torch
.
Tensor
,
gamma
:
float
=
2.2
)
->
torch
.
Tensor
:
"""
Applies gamma correction to an HDR image tensor.
"""
return
torch
.
clamp
(
x
**
(
1.0
/
gamma
),
0.0
,
1.0
)
class
PbrMeshRenderer
:
"""
Renderer for the PBR mesh.
Args:
rendering_options (dict): Rendering options.
"""
def
__init__
(
self
,
rendering_options
=
{},
device
=
'cuda'
):
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
self
.
rendering_options
=
edict
({
"resolution"
:
None
,
"near"
:
None
,
"far"
:
None
,
"ssaa"
:
1
,
"peel_layers"
:
8
,
})
self
.
rendering_options
.
update
(
rendering_options
)
self
.
glctx
=
dr
.
RasterizeCudaContext
(
device
=
device
)
self
.
device
=
device
def
render
(
self
,
mesh
:
Mesh
,
extrinsics
:
torch
.
Tensor
,
intrinsics
:
torch
.
Tensor
,
envmap
:
Union
[
EnvMap
,
Dict
[
str
,
EnvMap
]],
use_envmap_bg
:
bool
=
False
,
transformation
:
Optional
[
torch
.
Tensor
]
=
None
)
->
edict
:
"""
Render the mesh.
Args:
mesh : meshmodel
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps
use_envmap_bg (bool): whether to use envmap as background
transformation (torch.Tensor): (4, 4) transformation matrix
Returns:
edict based on return_types containing:
shaded (torch.Tensor): [3, H, W] shaded color image
normal (torch.Tensor): [3, H, W] normal image
base_color (torch.Tensor): [3, H, W] base color image
metallic (torch.Tensor): [H, W] metallic image
roughness (torch.Tensor): [H, W] roughness image
"""
if
'dr'
not
in
globals
():
import
nvdiffrast.torch
as
dr
if
not
isinstance
(
envmap
,
dict
):
envmap
=
{
''
:
envmap
}
num_envmaps
=
len
(
envmap
)
resolution
=
self
.
rendering_options
[
"resolution"
]
near
=
self
.
rendering_options
[
"near"
]
far
=
self
.
rendering_options
[
"far"
]
ssaa
=
self
.
rendering_options
[
"ssaa"
]
if
mesh
.
vertices
.
shape
[
0
]
==
0
or
mesh
.
faces
.
shape
[
0
]
==
0
:
out_dict
=
edict
(
normal
=
torch
.
zeros
((
3
,
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
mask
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
base_color
=
torch
.
zeros
((
3
,
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
metallic
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
roughness
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
alpha
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
clay
=
torch
.
zeros
((
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
)
for
i
,
k
in
enumerate
(
envmap
.
keys
()):
shaded_key
=
f
"shaded_
{
k
}
"
if
k
!=
''
else
"shaded"
out_dict
[
shaded_key
]
=
torch
.
zeros
((
3
,
resolution
,
resolution
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
return
out_dict
rays_o
,
rays_d
=
utils3d
.
torch
.
get_image_rays
(
extrinsics
,
intrinsics
,
resolution
*
ssaa
,
resolution
*
ssaa
)
perspective
=
intrinsics_to_projection
(
intrinsics
,
near
,
far
)
full_proj
=
(
perspective
@
extrinsics
).
unsqueeze
(
0
)
extrinsics
=
extrinsics
.
unsqueeze
(
0
)
L
=
get_logger
()
section
(
f
"PbrMeshRenderer.render res=
{
resolution
}
ssaa=
{
ssaa
}
"
)
vertices
=
mesh
.
vertices
.
unsqueeze
(
0
)
vertices_orig
=
vertices
.
clone
()
vertices_homo
=
torch
.
cat
([
vertices
,
torch
.
ones_like
(
vertices
[...,
:
1
])],
dim
=-
1
)
if
transformation
is
not
None
:
vertices_homo
=
_safe_transform4x4
(
vertices_homo
,
transformation
.
unsqueeze
(
0
).
transpose
(
-
1
,
-
2
))
vertices
=
vertices_homo
[...,
:
3
].
contiguous
()
vertices_camera
=
_safe_transform4x4
(
vertices_homo
,
extrinsics
.
transpose
(
-
1
,
-
2
))
vertices_clip
=
_safe_transform4x4
(
vertices_homo
,
full_proj
.
transpose
(
-
1
,
-
2
))
faces
=
mesh
.
faces
# ── Pre-rasterize sanity checks ──────────────────────────────────────
log_mesh
(
mesh
.
vertices
,
mesh
.
faces
,
"renderer-input"
)
L
.
info
(
f
"
{
elapsed
()
}
full_proj:
\n
{
full_proj
[
0
].
cpu
().
numpy
()
}
"
)
vc
=
vertices_clip
[
0
]
# [N, 4]
has_nan
=
torch
.
isnan
(
vc
).
any
().
item
()
has_inf
=
torch
.
isinf
(
vc
).
any
().
item
()
w_min
,
w_max
=
vc
[:,
3
].
min
().
item
(),
vc
[:,
3
].
max
().
item
()
w_zero
=
(
vc
[:,
3
].
abs
()
<
1e-6
).
sum
().
item
()
L
.
info
(
f
"
{
elapsed
()
}
vertices_clip: shape=
{
list
(
vc
.
shape
)
}
"
f
"NaN=
{
has_nan
}
inf=
{
has_inf
}
"
f
"x=[
{
vc
[:,
0
].
min
().
item
():.
4
g
}
,
{
vc
[:,
0
].
max
().
item
():.
4
g
}
] "
f
"y=[
{
vc
[:,
1
].
min
().
item
():.
4
g
}
,
{
vc
[:,
1
].
max
().
item
():.
4
g
}
] "
f
"z=[
{
vc
[:,
2
].
min
().
item
():.
4
g
}
,
{
vc
[:,
2
].
max
().
item
():.
4
g
}
] "
f
"w=[
{
w_min
:.
4
g
}
,
{
w_max
:.
4
g
}
] w_zeros=
{
w_zero
}
"
)
if
has_nan
or
has_inf
:
L
.
error
(
" ⚠ vertices_clip has NaN/inf — rasterizer will produce garbage!"
)
if
w_min
<
0
:
L
.
warning
(
f
" ⚠ vertices_clip has negative w values (
{
(
vc
[:,
3
]
<
0
).
sum
().
item
()
}
vertices)"
" — behind camera, may cause artifacts"
)
# NDC coords after perspective divide
ndc
=
vc
[:,
:
3
]
/
vc
[:,
3
:
4
].
clamp
(
min
=
1e-6
)
L
.
info
(
f
"
{
elapsed
()
}
NDC (after w-divide): "
f
"x=[
{
ndc
[:,
0
].
min
().
item
():.
4
g
}
,
{
ndc
[:,
0
].
max
().
item
():.
4
g
}
] "
f
"y=[
{
ndc
[:,
1
].
min
().
item
():.
4
g
}
,
{
ndc
[:,
1
].
max
().
item
():.
4
g
}
] "
f
"z=[
{
ndc
[:,
2
].
min
().
item
():.
4
g
}
,
{
ndc
[:,
2
].
max
().
item
():.
4
g
}
] "
f
"out_of_frustum=
{
(
ndc
.
abs
()
>
1.0
).
any
(
dim
=
1
).
sum
().
item
()
}
/
{
vc
.
shape
[
0
]
}
"
)
# Normal computation is skipped — all GPU and CPU smooth-normal approaches
# produce artifacts on ROCm GFX1201 for large meshes.
# A constant normal is used instead: normal view will be flat, but PBR/clay
# renders will be artifact-free.
_faces_cpu
=
mesh
.
faces
.
long
().
cpu
()
# [F, 3] — needed in the render loop
out_dict
=
edict
()
shaded
=
torch
.
zeros
((
num_envmaps
,
resolution
*
ssaa
,
resolution
*
ssaa
,
3
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
depth
=
torch
.
full
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
1e10
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
normal
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
3
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
max_w
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
alpha
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
rast_test
,
_
=
dr
.
rasterize
(
self
.
glctx
,
vertices_clip
,
faces
,
resolution
=
[
resolution
*
ssaa
,
resolution
*
ssaa
])
max_tri_id
=
rast_test
[...,
-
1
].
max
().
item
()
visible_px
=
(
rast_test
[...,
-
1
]
>
0
).
sum
().
item
()
total_px
=
(
resolution
*
ssaa
)
**
2
L
.
info
(
f
"
{
elapsed
()
}
rasterize test: max_tri_id=
{
max_tri_id
:.
0
f
}
"
f
"visible_px=
{
visible_px
}
/
{
total_px
}
(
{
100.
*
visible_px
/
total_px
:.
1
f
}
%)"
)
if
max_tri_id
>
mesh
.
faces
.
shape
[
0
]:
L
.
error
(
f
" ⚠ max_tri_id
{
max_tri_id
}
> num_faces
{
mesh
.
faces
.
shape
[
0
]
}
— CORRUPT RASTERIZE OUTPUT"
)
with
dr
.
DepthPeeler
(
self
.
glctx
,
vertices_clip
,
faces
,
(
resolution
*
ssaa
,
resolution
*
ssaa
))
as
peeler
:
for
_
in
range
(
self
.
rendering_options
[
"peel_layers"
]):
rast
,
rast_db
=
peeler
.
rasterize_next_layer
()
if
_
in
[
0
,
1
,
2
]:
visible_pixels
=
(
rast
[...,
-
1
]
>
0
).
sum
().
item
()
L
.
info
(
f
"
{
elapsed
()
}
DepthPeel layer=
{
_
}
visible_px=
{
visible_pixels
}
"
)
# Pos
pos
=
dr
.
interpolate
(
vertices
,
rast
,
faces
)[
0
][
0
]
# Depth
gb_depth
=
dr
.
interpolate
(
vertices_camera
[...,
2
:
3
].
contiguous
(),
rast
,
faces
)[
0
][
0
]
# Constant normal pointing toward the camera (-Z in camera space).
# Smooth normal computation is unreliable on ROCm GFX1201 large meshes.
H
=
rast
.
shape
[
1
];
W
=
rast
.
shape
[
2
]
gb_normal
=
torch
.
zeros
(
H
,
W
,
3
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
gb_normal
[...,
2
]
=
-
1.0
gb_normal
=
gb_normal
*
(
rast
[
0
,
...,
3
:
4
]
>
0
).
float
()
gb_cam_normal
=
(
extrinsics
[...,
:
3
,
:
3
].
reshape
(
1
,
1
,
3
,
3
)
@
gb_normal
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
if
_
==
0
:
out_dict
.
normal
=
-
gb_cam_normal
*
0.5
+
0.5
mask
=
(
rast
[
0
,
...,
-
1
:]
>
0
).
float
()
out_dict
.
mask
=
mask
# PBR attributes
if
isinstance
(
mesh
,
MeshWithVoxel
):
if
'grid_sample_3d'
not
in
globals
():
from
flex_gemm.ops.grid_sample
import
grid_sample_3d
mask
=
rast
[...,
-
1
:]
>
0
xyz
=
dr
.
interpolate
(
vertices_orig
,
rast
,
faces
)[
0
]
xyz
=
((
xyz
-
mesh
.
origin
)
/
mesh
.
voxel_size
).
reshape
(
1
,
-
1
,
3
)
img
=
grid_sample_3d
(
mesh
.
attrs
,
torch
.
cat
([
torch
.
zeros_like
(
mesh
.
coords
[...,
:
1
]),
mesh
.
coords
],
dim
=-
1
),
mesh
.
voxel_shape
,
xyz
,
mode
=
'trilinear'
)
img
=
img
.
reshape
(
1
,
resolution
*
ssaa
,
resolution
*
ssaa
,
mesh
.
attrs
.
shape
[
-
1
])
*
mask
gb_basecolor
=
img
[
0
,
...,
mesh
.
layout
[
'base_color'
]]
gb_metallic
=
img
[
0
,
...,
mesh
.
layout
[
'metallic'
]]
gb_roughness
=
img
[
0
,
...,
mesh
.
layout
[
'roughness'
]]
gb_alpha
=
img
[
0
,
...,
mesh
.
layout
[
'alpha'
]]
elif
isinstance
(
mesh
,
MeshWithPbrMaterial
):
tri_id
=
rast
[
0
,
:,
:,
-
1
:]
mask
=
tri_id
>
0
if
_
==
0
:
# log once per render call
L
.
info
(
f
"
{
elapsed
()
}
MeshWithPbrMaterial: "
f
"uv_coords=
{
list
(
mesh
.
uv_coords
.
shape
)
}
"
f
"material_ids=
{
list
(
mesh
.
material_ids
.
shape
)
}
"
f
"num_materials=
{
len
(
mesh
.
materials
)
}
"
)
log_uv
(
mesh
.
uv_coords
.
reshape
(
-
1
,
2
),
"mesh.uv_coords"
)
fi_min
=
mesh
.
material_ids
.
min
().
item
()
fi_max
=
mesh
.
material_ids
.
max
().
item
()
L
.
info
(
f
"
{
elapsed
()
}
material_ids range=[
{
fi_min
}
,
{
fi_max
}
] "
f
"num_materials=
{
len
(
mesh
.
materials
)
}
"
)
if
fi_max
>=
len
(
mesh
.
materials
):
L
.
error
(
f
" ⚠ material_ids max
{
fi_max
}
>= num_materials
{
len
(
mesh
.
materials
)
}
!"
)
uv_coords
=
mesh
.
uv_coords
.
reshape
(
1
,
-
1
,
2
)
texc
,
texd
=
dr
.
interpolate
(
uv_coords
,
rast
,
torch
.
arange
(
mesh
.
uv_coords
.
shape
[
0
]
*
3
,
dtype
=
torch
.
int
,
device
=
self
.
device
).
reshape
(
-
1
,
3
),
rast_db
=
rast_db
,
diff_attrs
=
'all'
)
if
_
==
0
:
log_tensor
(
texc
,
"texc-pre-clamp"
)
# Fix problematic texture coordinates
texc
=
torch
.
nan_to_num
(
texc
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texc
=
torch
.
clamp
(
texc
,
min
=-
1e3
,
max
=
1e3
)
texd
=
torch
.
nan_to_num
(
texd
,
nan
=
0.0
,
posinf
=
1e3
,
neginf
=-
1e3
)
texd
=
torch
.
clamp
(
texd
,
min
=-
1e3
,
max
=
1e3
)
if
_
==
0
:
log_tensor
(
texc
,
"texc-post-clamp"
)
mid
=
mesh
.
material_ids
[(
tri_id
-
1
).
long
()]
gb_basecolor
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
3
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
gb_metallic
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
gb_roughness
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
gb_alpha
=
torch
.
zeros
((
resolution
*
ssaa
,
resolution
*
ssaa
,
1
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
id
,
mat
in
enumerate
(
mesh
.
materials
):
mat_mask
=
(
mid
==
id
).
float
()
*
mask
.
float
()
mat_texc
=
texc
*
mat_mask
mat_texd
=
texd
*
mat_mask
if
mat
.
base_color_texture
is
not
None
:
bc
=
dr
.
texture
(
mat
.
base_color_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
base_color_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
base_color_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
gb_basecolor
+=
bc
*
mat
.
base_color_factor
*
mat_mask
else
:
gb_basecolor
+=
mat
.
base_color_factor
*
mat_mask
if
mat
.
metallic_texture
is
not
None
:
m
=
dr
.
texture
(
mat
.
metallic_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
metallic_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
metallic_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
gb_metallic
+=
m
*
mat
.
metallic_factor
*
mat_mask
else
:
gb_metallic
+=
mat
.
metallic_factor
*
mat_mask
if
mat
.
roughness_texture
is
not
None
:
r
=
dr
.
texture
(
mat
.
roughness_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
roughness_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
roughness_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
gb_roughness
+=
r
*
mat
.
roughness_factor
*
mat_mask
else
:
gb_roughness
+=
mat
.
roughness_factor
*
mat_mask
if
mat
.
alpha_mode
==
AlphaMode
.
OPAQUE
:
gb_alpha
+=
1.0
*
mat_mask
else
:
if
mat
.
alpha_texture
is
not
None
:
a
=
dr
.
texture
(
mat
.
alpha_texture
.
image
.
unsqueeze
(
0
),
mat_texc
,
mat_texd
,
filter_mode
=
'linear-mipmap-linear'
if
mat
.
alpha_texture
.
filter_mode
==
TextureFilterMode
.
LINEAR
else
'nearest'
,
boundary_mode
=
'clamp'
if
mat
.
alpha_texture
.
wrap_mode
==
TextureWrapMode
.
CLAMP_TO_EDGE
else
'wrap'
)[
0
]
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
gb_alpha
+=
(
a
*
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
gb_alpha
+=
a
*
mat
.
alpha_factor
*
mat_mask
else
:
if
mat
.
alpha_mode
==
AlphaMode
.
MASK
:
gb_alpha
+=
(
mat
.
alpha_factor
>
mat
.
alpha_cutoff
).
float
()
*
mat_mask
elif
mat
.
alpha_mode
==
AlphaMode
.
BLEND
:
gb_alpha
+=
mat
.
alpha_factor
*
mat_mask
if
_
==
0
:
out_dict
.
base_color
=
gb_basecolor
out_dict
.
metallic
=
gb_metallic
out_dict
.
roughness
=
gb_roughness
out_dict
.
alpha
=
gb_alpha
# Shading
gb_basecolor
=
torch
.
clamp
(
gb_basecolor
,
0.0
,
1.0
)
**
2.2
gb_metallic
=
torch
.
clamp
(
gb_metallic
,
0.0
,
1.0
)
gb_roughness
=
torch
.
clamp
(
gb_roughness
,
0.0
,
1.0
)
gb_alpha
=
torch
.
clamp
(
gb_alpha
,
0.0
,
1.0
)
gb_orm
=
torch
.
cat
([
torch
.
zeros_like
(
gb_metallic
),
gb_roughness
,
gb_metallic
,
],
dim
=-
1
)
_log
=
get_logger
()
_log
.
debug
(
f
"--- RASTERIZATION DEBUG --- pos sum:
{
pos
.
sum
().
item
()
}
| max:
{
pos
.
max
().
item
()
}
"
)
_log
.
debug
(
f
"gb_normal sum:
{
gb_normal
.
sum
().
item
()
}
| gb_basecolor sum:
{
gb_basecolor
.
sum
().
item
()
}
| gb_orm sum:
{
gb_orm
.
sum
().
item
()
}
| mask sum:
{
mask
.
float
().
sum
().
item
()
}
"
)
gb_shaded
=
torch
.
stack
([
e
.
shade
(
pos
.
unsqueeze
(
0
),
gb_normal
.
unsqueeze
(
0
),
gb_basecolor
.
unsqueeze
(
0
),
gb_orm
.
unsqueeze
(
0
),
rays_o
,
specular
=
True
,
)[
0
]
for
e
in
envmap
.
values
()
],
dim
=
0
)
# Compositing
w
=
(
1
-
alpha
)
*
gb_alpha
depth
=
torch
.
where
(
w
>
max_w
,
gb_depth
,
depth
)
normal
=
torch
.
where
(
w
>
max_w
,
gb_cam_normal
,
normal
)
max_w
=
torch
.
maximum
(
max_w
,
w
)
shaded
+=
w
*
gb_shaded
alpha
+=
w
# Ambient occulusion
f_occ
=
screen_space_ambient_occlusion
(
depth
,
normal
,
perspective
,
intensity
=
1.5
)
shaded
*=
(
1
-
f_occ
)
out_dict
.
clay
=
(
1
-
f_occ
)
# Background
if
use_envmap_bg
:
bg
=
torch
.
stack
([
e
.
sample
(
rays_d
)
for
e
in
envmap
.
values
()],
dim
=
0
)
shaded
+=
(
1
-
alpha
)
*
bg
for
i
,
k
in
enumerate
(
envmap
.
keys
()):
shaded_key
=
f
"shaded_
{
k
}
"
if
k
!=
''
else
"shaded"
out_dict
[
shaded_key
]
=
shaded
[
i
]
# SSAA
for
k
in
out_dict
.
keys
():
if
ssaa
>
1
:
out_dict
[
k
]
=
F
.
interpolate
(
out_dict
[
k
].
unsqueeze
(
0
).
permute
(
0
,
3
,
1
,
2
),
(
resolution
,
resolution
),
mode
=
'bilinear'
,
align_corners
=
False
,
antialias
=
True
)
else
:
out_dict
[
k
]
=
out_dict
[
k
].
permute
(
2
,
0
,
1
)
out_dict
[
k
]
=
out_dict
[
k
].
squeeze
()
# Post processing
for
k
in
envmap
.
keys
():
shaded_key
=
f
"shaded_
{
k
}
"
if
k
!=
''
else
"shaded"
out_dict
[
shaded_key
]
=
aces_tonemapping
(
out_dict
[
shaded_key
])
out_dict
[
shaded_key
]
=
gamma_correction
(
out_dict
[
shaded_key
])
return
out_dict
TRELLIS.2_DCU/trellis2/renderers/voxel_renderer.py
0 → 100644
View file @
f05e915f
import
torch
from
easydict
import
EasyDict
as
edict
from
..representations
import
Voxel
from
easydict
import
EasyDict
as
edict
class
VoxelRenderer
:
"""
Renderer for the Voxel representation.
Args:
rendering_options (dict): Rendering options.
"""
def
__init__
(
self
,
rendering_options
=
{})
->
None
:
self
.
rendering_options
=
edict
({
"resolution"
:
None
,
"near"
:
0.1
,
"far"
:
10.0
,
"ssaa"
:
1
,
})
self
.
rendering_options
.
update
(
rendering_options
)
def
render
(
self
,
voxel
:
Voxel
,
extrinsics
:
torch
.
Tensor
,
intrinsics
:
torch
.
Tensor
,
colors_overwrite
:
torch
.
Tensor
=
None
)
->
edict
:
"""
Render the gausssian.
Args:
voxel (Voxel): Voxel representation.
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
colors_overwrite (torch.Tensor): (N, 3) override color
Returns:
edict containing:
color (torch.Tensor): (3, H, W) rendered color image
depth (torch.Tensor): (H, W) rendered depth
alpha (torch.Tensor): (H, W) rendered alpha
...
"""
# lazy import
if
'o_voxel'
not
in
globals
():
import
o_voxel
renderer
=
o_voxel
.
rasterize
.
VoxelRenderer
(
self
.
rendering_options
)
positions
=
voxel
.
position
attrs
=
voxel
.
attrs
if
colors_overwrite
is
None
else
colors_overwrite
voxel_size
=
voxel
.
voxel_size
# Render
render_ret
=
renderer
.
render
(
positions
,
attrs
,
voxel_size
,
extrinsics
,
intrinsics
)
ret
=
{
'depth'
:
render_ret
[
'depth'
],
'alpha'
:
render_ret
[
'alpha'
],
}
if
colors_overwrite
is
not
None
:
ret
[
'color'
]
=
render_ret
[
'attr'
]
else
:
for
k
,
s
in
voxel
.
layout
.
items
():
ret
[
k
]
=
render_ret
[
'attr'
][
s
]
return
ret
TRELLIS.2_DCU/trellis2/representations/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
'Mesh'
:
'mesh'
,
'Voxel'
:
'voxel'
,
'MeshWithVoxel'
:
'mesh'
,
'MeshWithPbrMaterial'
:
'mesh'
,
}
__submodules
=
[]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
# For Pylance
if
__name__
==
'__main__'
:
from
.mesh
import
Mesh
,
MeshWithVoxel
,
MeshWithPbrMaterial
from
.voxel
import
Voxel
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment