Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2500 additions
and
0 deletions
+2500
-0
TRELLIS.2_DCU/trellis2/datasets/__init__.py
TRELLIS.2_DCU/trellis2/datasets/__init__.py
+46
-0
TRELLIS.2_DCU/trellis2/datasets/components.py
TRELLIS.2_DCU/trellis2/datasets/components.py
+192
-0
TRELLIS.2_DCU/trellis2/datasets/flexi_dual_grid.py
TRELLIS.2_DCU/trellis2/datasets/flexi_dual_grid.py
+173
-0
TRELLIS.2_DCU/trellis2/datasets/sparse_structure_latent.py
TRELLIS.2_DCU/trellis2/datasets/sparse_structure_latent.py
+160
-0
TRELLIS.2_DCU/trellis2/datasets/sparse_voxel_pbr.py
TRELLIS.2_DCU/trellis2/datasets/sparse_voxel_pbr.py
+298
-0
TRELLIS.2_DCU/trellis2/datasets/structured_latent.py
TRELLIS.2_DCU/trellis2/datasets/structured_latent.py
+210
-0
TRELLIS.2_DCU/trellis2/datasets/structured_latent_shape.py
TRELLIS.2_DCU/trellis2/datasets/structured_latent_shape.py
+96
-0
TRELLIS.2_DCU/trellis2/datasets/structured_latent_svpbr.py
TRELLIS.2_DCU/trellis2/datasets/structured_latent_svpbr.py
+290
-0
TRELLIS.2_DCU/trellis2/models/__init__.py
TRELLIS.2_DCU/trellis2/models/__init__.py
+78
-0
TRELLIS.2_DCU/trellis2/models/__pycache__/__init__.cpython-310.pyc
..._DCU/trellis2/models/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc
...2/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_structure_flow.cpython-310.pyc
.../models/__pycache__/sparse_structure_flow.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_structure_vae.cpython-310.pyc
...2/models/__pycache__/sparse_structure_vae.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/__pycache__/structured_latent_flow.cpython-310.pyc
...models/__pycache__/structured_latent_flow.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-310.pyc
...ellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/sc_vaes/__pycache__/sparse_unet_vae.cpython-310.pyc
...odels/sc_vaes/__pycache__/sparse_unet_vae.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/models/sc_vaes/fdg_vae.py
TRELLIS.2_DCU/trellis2/models/sc_vaes/fdg_vae.py
+117
-0
TRELLIS.2_DCU/trellis2/models/sc_vaes/sparse_unet_vae.py
TRELLIS.2_DCU/trellis2/models/sc_vaes/sparse_unet_vae.py
+569
-0
TRELLIS.2_DCU/trellis2/models/sparse_elastic_mixin.py
TRELLIS.2_DCU/trellis2/models/sparse_elastic_mixin.py
+24
-0
TRELLIS.2_DCU/trellis2/models/sparse_structure_flow.py
TRELLIS.2_DCU/trellis2/models/sparse_structure_flow.py
+247
-0
No files found.
TRELLIS.2_DCU/trellis2/datasets/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
'FlexiDualGridDataset'
:
'flexi_dual_grid'
,
'SparseVoxelPbrDataset'
:
'sparse_voxel_pbr'
,
'SparseStructureLatent'
:
'sparse_structure_latent'
,
'TextConditionedSparseStructureLatent'
:
'sparse_structure_latent'
,
'ImageConditionedSparseStructureLatent'
:
'sparse_structure_latent'
,
'SLat'
:
'structured_latent'
,
'ImageConditionedSLat'
:
'structured_latent'
,
'SLatShape'
:
'structured_latent_shape'
,
'ImageConditionedSLatShape'
:
'structured_latent_shape'
,
'SLatPbr'
:
'structured_latent_svpbr'
,
'ImageConditionedSLatPbr'
:
'structured_latent_svpbr'
,
}
__submodules
=
[]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
# For Pylance
if
__name__
==
'__main__'
:
from
.flexi_dual_grid
import
FlexiDualGridDataset
from
.sparse_voxel_pbr
import
SparseVoxelPbrDataset
from
.sparse_structure_latent
import
SparseStructureLatent
,
ImageConditionedSparseStructureLatent
from
.structured_latent
import
SLat
,
ImageConditionedSLat
from
.structured_latent_shape
import
SLatShape
,
ImageConditionedSLatShape
from
.structured_latent_svpbr
import
SLatPbr
,
ImageConditionedSLatPbr
\ No newline at end of file
TRELLIS.2_DCU/trellis2/datasets/components.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
json
from
abc
import
abstractmethod
import
os
import
json
import
torch
import
numpy
as
np
import
pandas
as
pd
from
PIL
import
Image
from
torch.utils.data
import
Dataset
class
StandardDatasetBase
(
Dataset
):
"""
Base class for standard datasets.
Args:
roots (str): paths to the dataset
"""
def
__init__
(
self
,
roots
:
str
,
):
super
().
__init__
()
try
:
self
.
roots
=
json
.
loads
(
roots
)
root_type
=
'obj'
except
:
self
.
roots
=
roots
.
split
(
','
)
root_type
=
'list'
self
.
instances
=
[]
self
.
metadata
=
pd
.
DataFrame
()
self
.
_stats
=
{}
if
root_type
==
'obj'
:
for
key
,
root
in
self
.
roots
.
items
():
self
.
_stats
[
key
]
=
{}
metadata
=
pd
.
DataFrame
(
columns
=
[
'sha256'
]).
set_index
(
'sha256'
)
for
_
,
r
in
root
.
items
():
metadata
=
metadata
.
combine_first
(
pd
.
read_csv
(
os
.
path
.
join
(
r
,
'metadata.csv'
)).
set_index
(
'sha256'
))
self
.
_stats
[
key
][
'Total'
]
=
len
(
metadata
)
metadata
,
stats
=
self
.
filter_metadata
(
metadata
)
self
.
_stats
[
key
].
update
(
stats
)
self
.
instances
.
extend
([(
root
,
sha256
)
for
sha256
in
metadata
.
index
.
values
])
self
.
metadata
=
pd
.
concat
([
self
.
metadata
,
metadata
])
else
:
for
root
in
self
.
roots
:
key
=
os
.
path
.
basename
(
root
)
self
.
_stats
[
key
]
=
{}
metadata
=
pd
.
read_csv
(
os
.
path
.
join
(
root
,
'metadata.csv'
))
self
.
_stats
[
key
][
'Total'
]
=
len
(
metadata
)
metadata
,
stats
=
self
.
filter_metadata
(
metadata
)
self
.
_stats
[
key
].
update
(
stats
)
self
.
instances
.
extend
([(
root
,
sha256
)
for
sha256
in
metadata
[
'sha256'
].
values
])
metadata
.
set_index
(
'sha256'
,
inplace
=
True
)
self
.
metadata
=
pd
.
concat
([
self
.
metadata
,
metadata
])
@
abstractmethod
def
filter_metadata
(
self
,
metadata
:
pd
.
DataFrame
)
->
Tuple
[
pd
.
DataFrame
,
Dict
[
str
,
int
]]:
pass
@
abstractmethod
def
get_instance
(
self
,
root
,
instance
:
str
)
->
Dict
[
str
,
Any
]:
pass
def
__len__
(
self
):
return
len
(
self
.
instances
)
def
__getitem__
(
self
,
index
)
->
Dict
[
str
,
Any
]:
try
:
root
,
instance
=
self
.
instances
[
index
]
return
self
.
get_instance
(
root
,
instance
)
except
Exception
as
e
:
print
(
f
'Error loading
{
instance
}
:
{
e
}
'
)
return
self
.
__getitem__
(
np
.
random
.
randint
(
0
,
len
(
self
)))
def
__str__
(
self
):
lines
=
[]
lines
.
append
(
self
.
__class__
.
__name__
)
lines
.
append
(
f
' - Total instances:
{
len
(
self
)
}
'
)
lines
.
append
(
f
' - Sources:'
)
for
key
,
stats
in
self
.
_stats
.
items
():
lines
.
append
(
f
' -
{
key
}
:'
)
for
k
,
v
in
stats
.
items
():
lines
.
append
(
f
' -
{
k
}
:
{
v
}
'
)
return
'
\n
'
.
join
(
lines
)
class
ImageConditionedMixin
:
def
__init__
(
self
,
roots
,
*
,
image_size
=
518
,
**
kwargs
):
self
.
image_size
=
image_size
super
().
__init__
(
roots
,
**
kwargs
)
def
filter_metadata
(
self
,
metadata
):
metadata
,
stats
=
super
().
filter_metadata
(
metadata
)
metadata
=
metadata
[
metadata
[
'cond_rendered'
].
notna
()]
stats
[
'Cond rendered'
]
=
len
(
metadata
)
return
metadata
,
stats
def
get_instance
(
self
,
root
,
instance
):
pack
=
super
().
get_instance
(
root
,
instance
)
image_root
=
os
.
path
.
join
(
root
[
'render_cond'
],
instance
)
with
open
(
os
.
path
.
join
(
image_root
,
'transforms.json'
))
as
f
:
metadata
=
json
.
load
(
f
)
n_views
=
len
(
metadata
[
'frames'
])
view
=
np
.
random
.
randint
(
n_views
)
metadata
=
metadata
[
'frames'
][
view
]
image_path
=
os
.
path
.
join
(
image_root
,
metadata
[
'file_path'
])
image
=
Image
.
open
(
image_path
)
alpha
=
np
.
array
(
image
.
getchannel
(
3
))
bbox
=
np
.
array
(
alpha
).
nonzero
()
bbox
=
[
bbox
[
1
].
min
(),
bbox
[
0
].
min
(),
bbox
[
1
].
max
(),
bbox
[
0
].
max
()]
center
=
[(
bbox
[
0
]
+
bbox
[
2
])
/
2
,
(
bbox
[
1
]
+
bbox
[
3
])
/
2
]
hsize
=
max
(
bbox
[
2
]
-
bbox
[
0
],
bbox
[
3
]
-
bbox
[
1
])
/
2
aug_hsize
=
hsize
aug_center_offset
=
[
0
,
0
]
aug_center
=
[
center
[
0
]
+
aug_center_offset
[
0
],
center
[
1
]
+
aug_center_offset
[
1
]]
aug_bbox
=
[
int
(
aug_center
[
0
]
-
aug_hsize
),
int
(
aug_center
[
1
]
-
aug_hsize
),
int
(
aug_center
[
0
]
+
aug_hsize
),
int
(
aug_center
[
1
]
+
aug_hsize
)]
image
=
image
.
crop
(
aug_bbox
)
image
=
image
.
resize
((
self
.
image_size
,
self
.
image_size
),
Image
.
Resampling
.
LANCZOS
)
alpha
=
image
.
getchannel
(
3
)
image
=
image
.
convert
(
'RGB'
)
image
=
torch
.
tensor
(
np
.
array
(
image
)).
permute
(
2
,
0
,
1
).
float
()
/
255.0
alpha
=
torch
.
tensor
(
np
.
array
(
alpha
)).
float
()
/
255.0
image
=
image
*
alpha
.
unsqueeze
(
0
)
pack
[
'cond'
]
=
image
return
pack
class
MultiImageConditionedMixin
:
def
__init__
(
self
,
roots
,
*
,
image_size
=
518
,
max_image_cond_view
=
4
,
**
kwargs
):
self
.
image_size
=
image_size
self
.
max_image_cond_view
=
max_image_cond_view
super
().
__init__
(
roots
,
**
kwargs
)
def
filter_metadata
(
self
,
metadata
):
metadata
,
stats
=
super
().
filter_metadata
(
metadata
)
metadata
=
metadata
[
metadata
[
'cond_rendered'
].
notna
()]
stats
[
'Cond rendered'
]
=
len
(
metadata
)
return
metadata
,
stats
def
get_instance
(
self
,
root
,
instance
):
pack
=
super
().
get_instance
(
root
,
instance
)
image_root
=
os
.
path
.
join
(
root
[
'render_cond'
],
instance
)
with
open
(
os
.
path
.
join
(
image_root
,
'transforms.json'
))
as
f
:
metadata
=
json
.
load
(
f
)
n_views
=
len
(
metadata
[
'frames'
])
n_sample_views
=
np
.
random
.
randint
(
1
,
self
.
max_image_cond_view
+
1
)
assert
n_views
>=
n_sample_views
,
f
'Not enough views to sample
{
n_sample_views
}
unique images.'
sampled_views
=
np
.
random
.
choice
(
n_views
,
size
=
n_sample_views
,
replace
=
False
)
cond_images
=
[]
for
v
in
sampled_views
:
frame_info
=
metadata
[
'frames'
][
v
]
image_path
=
os
.
path
.
join
(
image_root
,
frame_info
[
'file_path'
])
image
=
Image
.
open
(
image_path
)
alpha
=
np
.
array
(
image
.
getchannel
(
3
))
bbox
=
np
.
array
(
alpha
).
nonzero
()
bbox
=
[
bbox
[
1
].
min
(),
bbox
[
0
].
min
(),
bbox
[
1
].
max
(),
bbox
[
0
].
max
()]
center
=
[(
bbox
[
0
]
+
bbox
[
2
])
/
2
,
(
bbox
[
1
]
+
bbox
[
3
])
/
2
]
hsize
=
max
(
bbox
[
2
]
-
bbox
[
0
],
bbox
[
3
]
-
bbox
[
1
])
/
2
aug_hsize
=
hsize
aug_center
=
center
aug_bbox
=
[
int
(
aug_center
[
0
]
-
aug_hsize
),
int
(
aug_center
[
1
]
-
aug_hsize
),
int
(
aug_center
[
0
]
+
aug_hsize
),
int
(
aug_center
[
1
]
+
aug_hsize
),
]
img
=
image
.
crop
(
aug_bbox
)
img
=
img
.
resize
((
self
.
image_size
,
self
.
image_size
),
Image
.
Resampling
.
LANCZOS
)
alpha
=
img
.
getchannel
(
3
)
img
=
img
.
convert
(
'RGB'
)
img
=
torch
.
tensor
(
np
.
array
(
img
)).
permute
(
2
,
0
,
1
).
float
()
/
255.0
alpha
=
torch
.
tensor
(
np
.
array
(
alpha
)).
float
()
/
255.0
img
=
img
*
alpha
.
unsqueeze
(
0
)
cond_images
.
append
(
img
)
pack
[
'cond'
]
=
[
torch
.
stack
(
cond_images
,
dim
=
0
)]
# (V,3,H,W)
return
pack
TRELLIS.2_DCU/trellis2/datasets/flexi_dual_grid.py
0 → 100644
View file @
f05e915f
import
os
import
numpy
as
np
import
pickle
import
torch
import
utils3d
from
.components
import
StandardDatasetBase
from
..modules
import
sparse
as
sp
from
..renderers
import
MeshRenderer
from
..representations
import
Mesh
from
..utils.data_utils
import
load_balanced_group_indices
import
o_voxel
class
FlexiDualGridVisMixin
:
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
x
:
dict
):
mesh
=
x
[
'mesh'
]
renderer
=
MeshRenderer
({
'near'
:
1
,
'far'
:
3
})
renderer
.
rendering_options
.
resolution
=
512
renderer
.
rendering_options
.
ssaa
=
4
# Build camera
yaws
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaws_offset
=
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
yaws
=
[
y
+
yaws_offset
for
y
in
yaws
]
pitch
=
[
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
for
_
in
range
(
4
)]
exts
=
[]
ints
=
[]
for
yaw
,
pitch
in
zip
(
yaws
,
pitch
):
orig
=
torch
.
tensor
([
np
.
sin
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
cos
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
sin
(
pitch
),
]).
float
().
cuda
()
*
2
fov
=
torch
.
deg2rad
(
torch
.
tensor
(
30
)).
cuda
()
extrinsics
=
utils3d
.
torch
.
extrinsics_look_at
(
orig
,
torch
.
tensor
([
0
,
0
,
0
]).
float
().
cuda
(),
torch
.
tensor
([
0
,
0
,
1
]).
float
().
cuda
())
intrinsics
=
utils3d
.
torch
.
intrinsics_from_fov_xy
(
fov
,
fov
)
exts
.
append
(
extrinsics
)
ints
.
append
(
intrinsics
)
# Build each representation
images
=
[]
for
m
in
mesh
:
image
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
image
[:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
\
renderer
.
render
(
m
.
cuda
(),
ext
,
intr
)[
'normal'
]
images
.
append
(
image
)
images
=
torch
.
stack
(
images
)
return
images
class
FlexiDualGridDataset
(
FlexiDualGridVisMixin
,
StandardDatasetBase
):
"""
Flexible Dual Grid Dataset
Args:
roots (str): path to the dataset
resolution (int): resolution of the voxel grid
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
"""
def
__init__
(
self
,
roots
,
resolution
:
int
=
1024
,
max_active_voxels
:
int
=
1000000
,
max_num_faces
:
int
=
None
,
min_aesthetic_score
:
float
=
5.0
,
):
self
.
resolution
=
resolution
self
.
min_aesthetic_score
=
min_aesthetic_score
self
.
max_active_voxels
=
max_active_voxels
self
.
max_num_faces
=
max_num_faces
self
.
value_range
=
(
0
,
1
)
super
().
__init__
(
roots
)
self
.
loads
=
[
self
.
metadata
.
loc
[
sha256
,
f
'dual_grid_size'
]
for
_
,
sha256
in
self
.
instances
]
def
__str__
(
self
):
lines
=
[
super
().
__str__
(),
f
' - Resolution:
{
self
.
resolution
}
'
,
]
return
'
\n
'
.
join
(
lines
)
def
filter_metadata
(
self
,
metadata
):
stats
=
{}
metadata
=
metadata
[
metadata
[
f
'dual_grid_converted'
]
==
True
]
stats
[
'Dual Grid Converted'
]
=
len
(
metadata
)
if
self
.
min_aesthetic_score
is
not
None
:
metadata
=
metadata
[
metadata
[
'aesthetic_score'
]
>=
self
.
min_aesthetic_score
]
stats
[
f
'Aesthetic score >=
{
self
.
min_aesthetic_score
}
'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
f
'dual_grid_size'
]
<=
self
.
max_active_voxels
]
stats
[
f
'Active Voxels <=
{
self
.
max_active_voxels
}
'
]
=
len
(
metadata
)
if
self
.
max_num_faces
is
not
None
:
metadata
=
metadata
[
metadata
[
'num_faces'
]
<=
self
.
max_num_faces
]
stats
[
f
'Faces <=
{
self
.
max_num_faces
}
'
]
=
len
(
metadata
)
return
metadata
,
stats
def
read_mesh
(
self
,
root
,
instance
):
with
open
(
os
.
path
.
join
(
root
,
f
'
{
instance
}
.pickle'
),
'rb'
)
as
f
:
dump
=
pickle
.
load
(
f
)
start
=
0
vertices
=
[]
faces
=
[]
for
obj
in
dump
[
'objects'
]:
if
obj
[
'vertices'
].
size
==
0
or
obj
[
'faces'
].
size
==
0
:
continue
vertices
.
append
(
obj
[
'vertices'
])
faces
.
append
(
obj
[
'faces'
]
+
start
)
start
+=
len
(
obj
[
'vertices'
])
vertices
=
torch
.
from_numpy
(
np
.
concatenate
(
vertices
,
axis
=
0
)).
float
()
faces
=
torch
.
from_numpy
(
np
.
concatenate
(
faces
,
axis
=
0
)).
long
()
vertices_min
=
vertices
.
min
(
dim
=
0
)[
0
]
vertices_max
=
vertices
.
max
(
dim
=
0
)[
0
]
center
=
(
vertices_min
+
vertices_max
)
/
2
scale
=
0.99999
/
(
vertices_max
-
vertices_min
).
max
()
vertices
=
(
vertices
-
center
)
*
scale
assert
torch
.
all
(
vertices
>=
-
0.5
)
and
torch
.
all
(
vertices
<=
0.5
),
'vertices out of range'
return
{
'mesh'
:
[
Mesh
(
vertices
=
vertices
,
faces
=
faces
)]}
def
read_dual_grid
(
self
,
root
,
instance
):
coords
,
attr
=
o_voxel
.
io
.
read_vxz
(
os
.
path
.
join
(
root
,
f
'
{
instance
}
.vxz'
),
num_threads
=
4
)
vertices
=
sp
.
SparseTensor
(
(
attr
[
'vertices'
]
/
255.0
).
float
(),
torch
.
cat
([
torch
.
zeros_like
(
coords
[:,
0
:
1
]),
coords
],
dim
=-
1
),
)
intersected
=
vertices
.
replace
(
torch
.
cat
([
attr
[
'intersected'
]
%
2
,
attr
[
'intersected'
]
//
2
%
2
,
attr
[
'intersected'
]
//
4
%
2
,
],
dim
=-
1
).
bool
())
return
{
'vertices'
:
vertices
,
'intersected'
:
intersected
}
def
get_instance
(
self
,
root
,
instance
):
mesh
=
self
.
read_mesh
(
root
[
'mesh_dump'
],
instance
)
dual_grid
=
self
.
read_dual_grid
(
root
[
'dual_grid'
],
instance
)
return
{
**
mesh
,
**
dual_grid
}
@
staticmethod
def
collate_fn
(
batch
,
split_size
=
None
):
if
split_size
is
None
:
group_idx
=
[
list
(
range
(
len
(
batch
)))]
else
:
group_idx
=
load_balanced_group_indices
([
b
[
'vertices'
].
feats
.
shape
[
0
]
for
b
in
batch
],
split_size
)
packs
=
[]
for
group
in
group_idx
:
sub_batch
=
[
batch
[
i
]
for
i
in
group
]
pack
=
{}
keys
=
[
k
for
k
in
sub_batch
[
0
].
keys
()]
for
k
in
keys
:
if
isinstance
(
sub_batch
[
0
][
k
],
torch
.
Tensor
):
pack
[
k
]
=
torch
.
stack
([
b
[
k
]
for
b
in
sub_batch
])
elif
isinstance
(
sub_batch
[
0
][
k
],
sp
.
SparseTensor
):
pack
[
k
]
=
sp
.
sparse_cat
([
b
[
k
]
for
b
in
sub_batch
],
dim
=
0
)
elif
isinstance
(
sub_batch
[
0
][
k
],
list
):
pack
[
k
]
=
sum
([
b
[
k
]
for
b
in
sub_batch
],
[])
else
:
pack
[
k
]
=
[
b
[
k
]
for
b
in
sub_batch
]
packs
.
append
(
pack
)
if
split_size
is
None
:
return
packs
[
0
]
return
packs
\ No newline at end of file
TRELLIS.2_DCU/trellis2/datasets/sparse_structure_latent.py
0 → 100644
View file @
f05e915f
import
os
import
json
from
typing
import
*
import
numpy
as
np
import
torch
from
..representations
import
Voxel
from
..renderers
import
VoxelRenderer
from
.components
import
StandardDatasetBase
,
ImageConditionedMixin
from
..
import
models
from
..utils.render_utils
import
yaw_pitch_r_fov_to_extrinsics_intrinsics
class
SparseStructureLatentVisMixin
:
def
__init__
(
self
,
*
args
,
pretrained_ss_dec
:
str
=
'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json'
,
ss_dec_path
:
Optional
[
str
]
=
None
,
ss_dec_ckpt
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
ss_dec
=
None
self
.
pretrained_ss_dec
=
pretrained_ss_dec
self
.
ss_dec_path
=
ss_dec_path
self
.
ss_dec_ckpt
=
ss_dec_ckpt
def
_loading_ss_dec
(
self
):
if
self
.
ss_dec
is
not
None
:
return
if
self
.
ss_dec_path
is
not
None
:
cfg
=
json
.
load
(
open
(
os
.
path
.
join
(
self
.
ss_dec_path
,
'config.json'
),
'r'
))
decoder
=
getattr
(
models
,
cfg
[
'models'
][
'decoder'
][
'name'
])(
**
cfg
[
'models'
][
'decoder'
][
'args'
])
ckpt_path
=
os
.
path
.
join
(
self
.
ss_dec_path
,
'ckpts'
,
f
'decoder_
{
self
.
ss_dec_ckpt
}
.pt'
)
decoder
.
load_state_dict
(
torch
.
load
(
ckpt_path
,
map_location
=
'cpu'
,
weights_only
=
True
))
else
:
decoder
=
models
.
from_pretrained
(
self
.
pretrained_ss_dec
)
self
.
ss_dec
=
decoder
.
cuda
().
eval
()
def
_delete_ss_dec
(
self
):
del
self
.
ss_dec
self
.
ss_dec
=
None
@
torch
.
no_grad
()
def
decode_latent
(
self
,
z
,
batch_size
=
4
):
self
.
_loading_ss_dec
()
ss
=
[]
if
self
.
normalization
:
z
=
z
*
self
.
std
.
to
(
z
.
device
)
+
self
.
mean
.
to
(
z
.
device
)
for
i
in
range
(
0
,
z
.
shape
[
0
],
batch_size
):
ss
.
append
(
self
.
ss_dec
(
z
[
i
:
i
+
batch_size
]))
ss
=
torch
.
cat
(
ss
,
dim
=
0
)
self
.
_delete_ss_dec
()
return
ss
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
x_0
:
Union
[
torch
.
Tensor
,
dict
]):
x_0
=
x_0
if
isinstance
(
x_0
,
torch
.
Tensor
)
else
x_0
[
'x_0'
]
x_0
=
self
.
decode_latent
(
x_0
.
cuda
())
renderer
=
VoxelRenderer
()
renderer
.
rendering_options
.
resolution
=
512
renderer
.
rendering_options
.
ssaa
=
4
# build camera
yaw
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaw_offset
=
-
16
/
180
*
np
.
pi
yaw
=
[
y
+
yaw_offset
for
y
in
yaw
]
pitch
=
[
20
/
180
*
np
.
pi
for
_
in
range
(
4
)]
exts
,
ints
=
yaw_pitch_r_fov_to_extrinsics_intrinsics
(
yaw
,
pitch
,
2
,
30
)
images
=
[]
# Build each representation
x_0
=
x_0
.
cuda
()
for
i
in
range
(
x_0
.
shape
[
0
]):
coords
=
torch
.
nonzero
(
x_0
[
i
,
0
]
>
0
,
as_tuple
=
False
)
resolution
=
x_0
.
shape
[
-
1
]
color
=
coords
/
resolution
rep
=
Voxel
(
origin
=
[
-
0.5
,
-
0.5
,
-
0.5
],
voxel_size
=
1
/
resolution
,
coords
=
coords
,
attrs
=
color
,
layout
=
{
'color'
:
slice
(
0
,
3
),
}
)
image
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
res
=
renderer
.
render
(
rep
,
ext
,
intr
,
colors_overwrite
=
color
)
image
[:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
res
[
'color'
]
images
.
append
(
image
)
return
torch
.
stack
(
images
)
class
SparseStructureLatent
(
SparseStructureLatentVisMixin
,
StandardDatasetBase
):
"""
Sparse structure latent dataset
Args:
roots (str): path to the dataset
min_aesthetic_score (float): minimum aesthetic score
normalization (dict): normalization stats
pretrained_ss_dec (str): name of the pretrained sparse structure decoder
ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
"""
def
__init__
(
self
,
roots
:
str
,
*
,
min_aesthetic_score
:
float
=
5.0
,
normalization
:
Optional
[
dict
]
=
None
,
pretrained_ss_dec
:
str
=
'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16'
,
ss_dec_path
:
Optional
[
str
]
=
None
,
ss_dec_ckpt
:
Optional
[
str
]
=
None
,
):
self
.
min_aesthetic_score
=
min_aesthetic_score
self
.
normalization
=
normalization
self
.
value_range
=
(
0
,
1
)
super
().
__init__
(
roots
,
pretrained_ss_dec
=
pretrained_ss_dec
,
ss_dec_path
=
ss_dec_path
,
ss_dec_ckpt
=
ss_dec_ckpt
,
)
if
self
.
normalization
is
not
None
:
self
.
mean
=
torch
.
tensor
(
self
.
normalization
[
'mean'
]).
reshape
(
-
1
,
1
,
1
,
1
)
self
.
std
=
torch
.
tensor
(
self
.
normalization
[
'std'
]).
reshape
(
-
1
,
1
,
1
,
1
)
def
filter_metadata
(
self
,
metadata
):
stats
=
{}
metadata
=
metadata
[
metadata
[
'ss_latent_encoded'
]
==
True
]
stats
[
'With latent'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'aesthetic_score'
]
>=
self
.
min_aesthetic_score
]
stats
[
f
'Aesthetic score >=
{
self
.
min_aesthetic_score
}
'
]
=
len
(
metadata
)
return
metadata
,
stats
def
get_instance
(
self
,
root
,
instance
):
latent
=
np
.
load
(
os
.
path
.
join
(
root
[
'ss_latent'
],
f
'
{
instance
}
.npz'
))
z
=
torch
.
tensor
(
latent
[
'z'
]).
float
()
if
self
.
normalization
is
not
None
:
z
=
(
z
-
self
.
mean
)
/
self
.
std
pack
=
{
'x_0'
:
z
,
}
return
pack
class
ImageConditionedSparseStructureLatent
(
ImageConditionedMixin
,
SparseStructureLatent
):
"""
Image-conditioned sparse structure dataset
"""
pass
\ No newline at end of file
TRELLIS.2_DCU/trellis2/datasets/sparse_voxel_pbr.py
0 → 100644
View file @
f05e915f
import
os
import
io
from
typing
import
Union
import
numpy
as
np
import
pickle
import
torch
from
PIL
import
Image
import
o_voxel
import
utils3d
from
.components
import
StandardDatasetBase
from
..modules
import
sparse
as
sp
from
..renderers
import
VoxelRenderer
from
..representations
import
Voxel
from
..representations.mesh
import
MeshWithPbrMaterial
,
TextureFilterMode
,
TextureWrapMode
,
AlphaMode
,
PbrMaterial
,
Texture
from
..utils.data_utils
import
load_balanced_group_indices
def
is_power_of_two
(
n
:
int
)
->
bool
:
return
n
>
0
and
(
n
&
(
n
-
1
))
==
0
def
nearest_power_of_two
(
n
:
int
)
->
int
:
if
n
<
1
:
raise
ValueError
(
"n must be >= 1"
)
if
is_power_of_two
(
n
):
return
n
lower
=
2
**
(
n
.
bit_length
()
-
1
)
upper
=
2
**
n
.
bit_length
()
if
n
-
lower
<
upper
-
n
:
return
lower
else
:
return
upper
class
SparseVoxelPbrVisMixin
:
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
x
:
Union
[
sp
.
SparseTensor
,
dict
]):
x
=
x
if
isinstance
(
x
,
sp
.
SparseTensor
)
else
x
[
'x'
]
renderer
=
VoxelRenderer
()
renderer
.
rendering_options
.
resolution
=
512
renderer
.
rendering_options
.
ssaa
=
4
# Build camera
yaws
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaws_offset
=
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
yaws
=
[
y
+
yaws_offset
for
y
in
yaws
]
pitch
=
[
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
for
_
in
range
(
4
)]
exts
=
[]
ints
=
[]
for
yaw
,
pitch
in
zip
(
yaws
,
pitch
):
orig
=
torch
.
tensor
([
np
.
sin
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
cos
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
sin
(
pitch
),
]).
float
().
cuda
()
*
2
fov
=
torch
.
deg2rad
(
torch
.
tensor
(
30
)).
cuda
()
extrinsics
=
utils3d
.
torch
.
extrinsics_look_at
(
orig
,
torch
.
tensor
([
0
,
0
,
0
]).
float
().
cuda
(),
torch
.
tensor
([
0
,
0
,
1
]).
float
().
cuda
())
intrinsics
=
utils3d
.
torch
.
intrinsics_from_fov_xy
(
fov
,
fov
)
exts
.
append
(
extrinsics
)
ints
.
append
(
intrinsics
)
images
=
{
k
:
[]
for
k
in
self
.
layout
}
# Build each representation
x
=
x
.
cuda
()
for
i
in
range
(
x
.
shape
[
0
]):
rep
=
Voxel
(
origin
=
[
-
0.5
,
-
0.5
,
-
0.5
],
voxel_size
=
1
/
self
.
resolution
,
coords
=
x
[
i
].
coords
[:,
1
:].
contiguous
(),
attrs
=
None
,
layout
=
{
'color'
:
slice
(
0
,
3
),
}
)
for
k
in
self
.
layout
:
image
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
attr
=
x
[
i
].
feats
[:,
self
.
layout
[
k
]].
expand
(
-
1
,
3
)
res
=
renderer
.
render
(
rep
,
ext
,
intr
,
colors_overwrite
=
attr
)
image
[:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
res
[
'color'
]
images
[
k
].
append
(
image
)
for
k
in
self
.
layout
:
images
[
k
]
=
torch
.
stack
(
images
[
k
])
return
images
class
SparseVoxelPbrDataset
(
SparseVoxelPbrVisMixin
,
StandardDatasetBase
):
"""
Sparse Voxel PBR dataset.
Args:
roots (str): path to the dataset
resolution (int): resolution of the voxel grid
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
"""
def
__init__
(
self
,
roots
,
resolution
:
int
=
1024
,
max_active_voxels
:
int
=
1000000
,
max_num_faces
:
int
=
None
,
min_aesthetic_score
:
float
=
5.0
,
attrs
:
list
[
str
]
=
[
'base_color'
,
'metallic'
,
'roughness'
,
'emissive'
,
'alpha'
],
with_mesh
:
bool
=
True
,
):
self
.
resolution
=
resolution
self
.
min_aesthetic_score
=
min_aesthetic_score
self
.
max_active_voxels
=
max_active_voxels
self
.
max_num_faces
=
max_num_faces
self
.
with_mesh
=
with_mesh
self
.
value_range
=
(
-
1
,
1
)
self
.
channels
=
{
'base_color'
:
3
,
'metallic'
:
1
,
'roughness'
:
1
,
'emissive'
:
3
,
'alpha'
:
1
,
}
self
.
layout
=
{}
start
=
0
for
attr
in
attrs
:
self
.
layout
[
attr
]
=
slice
(
start
,
start
+
self
.
channels
[
attr
])
start
+=
self
.
channels
[
attr
]
super
().
__init__
(
roots
)
self
.
loads
=
[
self
.
metadata
.
loc
[
sha256
,
f
'num_pbr_voxels'
]
for
_
,
sha256
in
self
.
instances
]
def
__str__
(
self
):
lines
=
[
super
().
__str__
(),
f
' - Resolution:
{
self
.
resolution
}
'
,
f
' - Attributes:
{
list
(
self
.
layout
.
keys
())
}
'
,
]
return
'
\n
'
.
join
(
lines
)
def
filter_metadata
(
self
,
metadata
):
stats
=
{}
metadata
=
metadata
[
metadata
[
'pbr_voxelized'
]
==
True
]
stats
[
'PBR Voxelized'
]
=
len
(
metadata
)
if
self
.
min_aesthetic_score
is
not
None
:
metadata
=
metadata
[
metadata
[
'aesthetic_score'
]
>=
self
.
min_aesthetic_score
]
stats
[
f
'Aesthetic score >=
{
self
.
min_aesthetic_score
}
'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'num_pbr_voxels'
]
<=
self
.
max_active_voxels
]
stats
[
f
'Active voxels <=
{
self
.
max_active_voxels
}
'
]
=
len
(
metadata
)
if
self
.
max_num_faces
is
not
None
:
metadata
=
metadata
[
metadata
[
'num_faces'
]
<=
self
.
max_num_faces
]
stats
[
f
'Faces <=
{
self
.
max_num_faces
}
'
]
=
len
(
metadata
)
return
metadata
,
stats
@
staticmethod
def
_texture_from_dump
(
pack
)
->
Texture
:
png_bytes
=
pack
[
'image'
]
image
=
Image
.
open
(
io
.
BytesIO
(
png_bytes
))
if
image
.
width
!=
image
.
height
or
not
is_power_of_two
(
image
.
width
):
size
=
nearest_power_of_two
(
max
(
image
.
width
,
image
.
height
))
image
=
image
.
resize
((
size
,
size
),
Image
.
LANCZOS
)
texture
=
torch
.
tensor
(
np
.
array
(
image
)
/
255.0
,
dtype
=
torch
.
float32
).
reshape
(
image
.
height
,
image
.
width
,
-
1
)
filter_mode
=
{
'Linear'
:
TextureFilterMode
.
LINEAR
,
'Closest'
:
TextureFilterMode
.
CLOSEST
,
'Cubic'
:
TextureFilterMode
.
LINEAR
,
'Smart'
:
TextureFilterMode
.
LINEAR
,
}[
pack
[
'interpolation'
]]
wrap_mode
=
{
'REPEAT'
:
TextureWrapMode
.
REPEAT
,
'EXTEND'
:
TextureWrapMode
.
CLAMP_TO_EDGE
,
'CLIP'
:
TextureWrapMode
.
CLAMP_TO_EDGE
,
'MIRROR'
:
TextureWrapMode
.
MIRRORED_REPEAT
,
}[
pack
[
'extension'
]]
return
Texture
(
texture
,
filter_mode
=
filter_mode
,
wrap_mode
=
wrap_mode
)
def
read_mesh_with_texture
(
self
,
root
,
instance
):
with
open
(
os
.
path
.
join
(
root
,
f
'
{
instance
}
.pickle'
),
'rb'
)
as
f
:
dump
=
pickle
.
load
(
f
)
# Fix dump alpha map
for
mat
in
dump
[
'materials'
]:
if
mat
[
'alphaTexture'
]
is
not
None
and
mat
[
'alphaMode'
]
==
'OPAQUE'
:
mat
[
'alphaMode'
]
=
'BLEND'
# process material
materials
=
[]
for
mat
in
dump
[
'materials'
]:
materials
.
append
(
PbrMaterial
(
base_color_texture
=
self
.
_texture_from_dump
(
mat
[
'baseColorTexture'
])
if
mat
[
'baseColorTexture'
]
is
not
None
else
None
,
base_color_factor
=
mat
[
'baseColorFactor'
],
metallic_texture
=
self
.
_texture_from_dump
(
mat
[
'metallicTexture'
])
if
mat
[
'metallicTexture'
]
is
not
None
else
None
,
metallic_factor
=
mat
[
'metallicFactor'
],
roughness_texture
=
self
.
_texture_from_dump
(
mat
[
'roughnessTexture'
])
if
mat
[
'roughnessTexture'
]
is
not
None
else
None
,
roughness_factor
=
mat
[
'roughnessFactor'
],
alpha_texture
=
self
.
_texture_from_dump
(
mat
[
'alphaTexture'
])
if
mat
[
'alphaTexture'
]
is
not
None
else
None
,
alpha_factor
=
mat
[
'alphaFactor'
],
alpha_mode
=
{
'OPAQUE'
:
AlphaMode
.
OPAQUE
,
'MASK'
:
AlphaMode
.
MASK
,
'BLEND'
:
AlphaMode
.
BLEND
,
}[
mat
[
'alphaMode'
]],
alpha_cutoff
=
mat
[
'alphaCutoff'
],
))
materials
.
append
(
PbrMaterial
(
base_color_factor
=
[
0.8
,
0.8
,
0.8
],
alpha_factor
=
1.0
,
metallic_factor
=
0.0
,
roughness_factor
=
0.5
,
alpha_mode
=
AlphaMode
.
OPAQUE
,
alpha_cutoff
=
0.5
,
))
# append default material
# process mesh
start
=
0
vertices
=
[]
faces
=
[]
material_ids
=
[]
uv_coords
=
[]
for
obj
in
dump
[
'objects'
]:
if
obj
[
'vertices'
].
size
==
0
or
obj
[
'faces'
].
size
==
0
:
continue
vertices
.
append
(
obj
[
'vertices'
])
faces
.
append
(
obj
[
'faces'
]
+
start
)
obj
[
'mat_ids'
][
obj
[
'mat_ids'
]
==
-
1
]
=
len
(
materials
)
-
1
material_ids
.
append
(
obj
[
'mat_ids'
])
uv_coords
.
append
(
obj
[
'uvs'
]
if
obj
[
'uvs'
]
is
not
None
else
np
.
zeros
((
obj
[
'faces'
].
shape
[
0
],
3
,
2
),
dtype
=
np
.
float32
))
start
+=
len
(
obj
[
'vertices'
])
vertices
=
torch
.
from_numpy
(
np
.
concatenate
(
vertices
,
axis
=
0
)).
float
()
faces
=
torch
.
from_numpy
(
np
.
concatenate
(
faces
,
axis
=
0
)).
long
()
material_ids
=
torch
.
from_numpy
(
np
.
concatenate
(
material_ids
,
axis
=
0
)).
long
()
uv_coords
=
torch
.
from_numpy
(
np
.
concatenate
(
uv_coords
,
axis
=
0
)).
float
()
# Normalize vertices
vertices_min
=
vertices
.
min
(
dim
=
0
)[
0
]
vertices_max
=
vertices
.
max
(
dim
=
0
)[
0
]
center
=
(
vertices_min
+
vertices_max
)
/
2
scale
=
0.99999
/
(
vertices_max
-
vertices_min
).
max
()
vertices
=
(
vertices
-
center
)
*
scale
assert
torch
.
all
(
vertices
>=
-
0.5
)
and
torch
.
all
(
vertices
<=
0.5
),
'vertices out of range'
return
{
'mesh'
:
[
MeshWithPbrMaterial
(
vertices
=
vertices
,
faces
=
faces
,
material_ids
=
material_ids
,
uv_coords
=
uv_coords
,
materials
=
materials
,
)]}
def
read_pbr_voxel
(
self
,
root
,
instance
):
coords
,
attr
=
o_voxel
.
io
.
read_vxz
(
os
.
path
.
join
(
root
,
f
'
{
instance
}
.vxz'
),
num_threads
=
4
)
feats
=
torch
.
concat
([
attr
[
k
]
for
k
in
self
.
layout
],
dim
=-
1
)
/
255.0
*
2
-
1
x
=
sp
.
SparseTensor
(
feats
.
float
(),
torch
.
cat
([
torch
.
zeros_like
(
coords
[:,
0
:
1
]),
coords
],
dim
=-
1
),
)
return
{
'x'
:
x
}
def
get_instance
(
self
,
root
,
instance
):
if
self
.
with_mesh
:
mesh
=
self
.
read_mesh_with_texture
(
root
[
'pbr_dump'
],
instance
)
pbr_voxel
=
self
.
read_pbr_voxel
(
root
[
'pbr_voxel'
],
instance
)
return
{
**
mesh
,
**
pbr_voxel
}
else
:
return
self
.
read_pbr_voxel
(
root
[
'pbr_voxel'
],
instance
)
@
staticmethod
def
collate_fn
(
batch
,
split_size
=
None
):
if
split_size
is
None
:
group_idx
=
[
list
(
range
(
len
(
batch
)))]
else
:
group_idx
=
load_balanced_group_indices
([
b
[
'x'
].
feats
.
shape
[
0
]
for
b
in
batch
],
split_size
)
packs
=
[]
for
group
in
group_idx
:
sub_batch
=
[
batch
[
i
]
for
i
in
group
]
pack
=
{}
keys
=
[
k
for
k
in
sub_batch
[
0
].
keys
()]
for
k
in
keys
:
if
isinstance
(
sub_batch
[
0
][
k
],
torch
.
Tensor
):
pack
[
k
]
=
torch
.
stack
([
b
[
k
]
for
b
in
sub_batch
])
elif
isinstance
(
sub_batch
[
0
][
k
],
sp
.
SparseTensor
):
pack
[
k
]
=
sp
.
sparse_cat
([
b
[
k
]
for
b
in
sub_batch
],
dim
=
0
)
elif
isinstance
(
sub_batch
[
0
][
k
],
list
):
pack
[
k
]
=
sum
([
b
[
k
]
for
b
in
sub_batch
],
[])
else
:
pack
[
k
]
=
[
b
[
k
]
for
b
in
sub_batch
]
packs
.
append
(
pack
)
if
split_size
is
None
:
return
packs
[
0
]
return
packs
TRELLIS.2_DCU/trellis2/datasets/structured_latent.py
0 → 100644
View file @
f05e915f
import
json
import
os
from
typing
import
*
import
numpy
as
np
import
torch
import
utils3d.torch
from
.components
import
StandardDatasetBase
,
ImageConditionedMixin
from
..modules.sparse.basic
import
SparseTensor
from
..
import
models
from
..utils.render_utils
import
get_renderer
from
..utils.data_utils
import
load_balanced_group_indices
class
SLatVisMixin
:
def
__init__
(
self
,
*
args
,
pretrained_slat_dec
:
str
=
'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16'
,
slat_dec_path
:
Optional
[
str
]
=
None
,
slat_dec_ckpt
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
slat_dec
=
None
self
.
pretrained_slat_dec
=
pretrained_slat_dec
self
.
slat_dec_path
=
slat_dec_path
self
.
slat_dec_ckpt
=
slat_dec_ckpt
def
_loading_slat_dec
(
self
):
if
self
.
slat_dec
is
not
None
:
return
if
self
.
slat_dec_path
is
not
None
:
cfg
=
json
.
load
(
open
(
os
.
path
.
join
(
self
.
slat_dec_path
,
'config.json'
),
'r'
))
decoder
=
getattr
(
models
,
cfg
[
'models'
][
'decoder'
][
'name'
])(
**
cfg
[
'models'
][
'decoder'
][
'args'
])
ckpt_path
=
os
.
path
.
join
(
self
.
slat_dec_path
,
'ckpts'
,
f
'decoder_
{
self
.
slat_dec_ckpt
}
.pt'
)
decoder
.
load_state_dict
(
torch
.
load
(
ckpt_path
,
map_location
=
'cpu'
,
weights_only
=
True
))
else
:
decoder
=
models
.
from_pretrained
(
self
.
pretrained_slat_dec
)
self
.
slat_dec
=
decoder
.
cuda
().
eval
()
def
_delete_slat_dec
(
self
):
del
self
.
slat_dec
self
.
slat_dec
=
None
@
torch
.
no_grad
()
def
decode_latent
(
self
,
z
,
batch_size
=
4
):
self
.
_loading_slat_dec
()
reps
=
[]
if
self
.
normalization
is
not
None
:
z
=
z
*
self
.
std
.
to
(
z
.
device
)
+
self
.
mean
.
to
(
z
.
device
)
for
i
in
range
(
0
,
z
.
shape
[
0
],
batch_size
):
reps
.
append
(
self
.
slat_dec
(
z
[
i
:
i
+
batch_size
]))
reps
=
sum
(
reps
,
[])
self
.
_delete_slat_dec
()
return
reps
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
x_0
:
Union
[
SparseTensor
,
dict
]):
x_0
=
x_0
if
isinstance
(
x_0
,
SparseTensor
)
else
x_0
[
'x_0'
]
reps
=
self
.
decode_latent
(
x_0
.
cuda
())
# Build camera
yaws
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaws_offset
=
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
yaws
=
[
y
+
yaws_offset
for
y
in
yaws
]
pitch
=
[
np
.
random
.
uniform
(
-
np
.
pi
/
4
,
np
.
pi
/
4
)
for
_
in
range
(
4
)]
exts
=
[]
ints
=
[]
for
yaw
,
pitch
in
zip
(
yaws
,
pitch
):
orig
=
torch
.
tensor
([
np
.
sin
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
cos
(
yaw
)
*
np
.
cos
(
pitch
),
np
.
sin
(
pitch
),
]).
float
().
cuda
()
*
2
fov
=
torch
.
deg2rad
(
torch
.
tensor
(
40
)).
cuda
()
extrinsics
=
utils3d
.
torch
.
extrinsics_look_at
(
orig
,
torch
.
tensor
([
0
,
0
,
0
]).
float
().
cuda
(),
torch
.
tensor
([
0
,
0
,
1
]).
float
().
cuda
())
intrinsics
=
utils3d
.
torch
.
intrinsics_from_fov_xy
(
fov
,
fov
)
exts
.
append
(
extrinsics
)
ints
.
append
(
intrinsics
)
renderer
=
get_renderer
(
reps
[
0
])
images
=
[]
for
representation
in
reps
:
image
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
res
=
renderer
.
render
(
representation
,
ext
,
intr
)
image
[:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
res
[
'color'
]
images
.
append
(
image
)
images
=
torch
.
stack
(
images
)
return
images
class
SLat
(
SLatVisMixin
,
StandardDatasetBase
):
"""
structured latent V2 dataset
Args:
roots (str): path to the dataset
min_aesthetic_score (float): minimum aesthetic score
max_tokens (int): maximum number of tokens
latent_key (str): key of the latent to be used
normalization (dict): normalization stats
pretrained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def
__init__
(
self
,
roots
:
str
,
*
,
min_aesthetic_score
:
float
=
5.0
,
max_tokens
:
int
=
32768
,
latent_key
:
str
=
'shape_latent'
,
normalization
:
Optional
[
dict
]
=
None
,
pretrained_slat_dec
:
str
=
'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16'
,
slat_dec_path
:
Optional
[
str
]
=
None
,
slat_dec_ckpt
:
Optional
[
str
]
=
None
,
):
self
.
normalization
=
normalization
self
.
min_aesthetic_score
=
min_aesthetic_score
self
.
max_tokens
=
max_tokens
self
.
latent_key
=
latent_key
self
.
value_range
=
(
0
,
1
)
super
().
__init__
(
roots
,
pretrained_slat_dec
=
pretrained_slat_dec
,
slat_dec_path
=
slat_dec_path
,
slat_dec_ckpt
=
slat_dec_ckpt
,
)
self
.
loads
=
[
self
.
metadata
.
loc
[
sha256
,
f
'
{
latent_key
}
_tokens'
]
for
_
,
sha256
in
self
.
instances
]
if
self
.
normalization
is
not
None
:
self
.
mean
=
torch
.
tensor
(
self
.
normalization
[
'mean'
]).
reshape
(
1
,
-
1
)
self
.
std
=
torch
.
tensor
(
self
.
normalization
[
'std'
]).
reshape
(
1
,
-
1
)
def
filter_metadata
(
self
,
metadata
):
stats
=
{}
metadata
=
metadata
[
metadata
[
f
'
{
self
.
latent_key
}
_encoded'
]
==
True
]
stats
[
'With latent'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'aesthetic_score'
]
>=
self
.
min_aesthetic_score
]
stats
[
f
'Aesthetic score >=
{
self
.
min_aesthetic_score
}
'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
f
'
{
self
.
latent_key
}
_tokens'
]
<=
self
.
max_tokens
]
stats
[
f
'Num tokens <=
{
self
.
max_tokens
}
'
]
=
len
(
metadata
)
return
metadata
,
stats
def
get_instance
(
self
,
root
,
instance
):
data
=
np
.
load
(
os
.
path
.
join
(
root
[
self
.
latent_key
],
f
'
{
instance
}
.npz'
))
coords
=
torch
.
tensor
(
data
[
'coords'
]).
int
()
feats
=
torch
.
tensor
(
data
[
'feats'
]).
float
()
if
self
.
normalization
is
not
None
:
feats
=
(
feats
-
self
.
mean
)
/
self
.
std
return
{
'coords'
:
coords
,
'feats'
:
feats
,
}
@
staticmethod
def
collate_fn
(
batch
,
split_size
=
None
):
if
split_size
is
None
:
group_idx
=
[
list
(
range
(
len
(
batch
)))]
else
:
group_idx
=
load_balanced_group_indices
([
b
[
'coords'
].
shape
[
0
]
for
b
in
batch
],
split_size
)
packs
=
[]
for
group
in
group_idx
:
sub_batch
=
[
batch
[
i
]
for
i
in
group
]
pack
=
{}
coords
=
[]
feats
=
[]
layout
=
[]
start
=
0
for
i
,
b
in
enumerate
(
sub_batch
):
coords
.
append
(
torch
.
cat
([
torch
.
full
((
b
[
'coords'
].
shape
[
0
],
1
),
i
,
dtype
=
torch
.
int32
),
b
[
'coords'
]],
dim
=-
1
))
feats
.
append
(
b
[
'feats'
])
layout
.
append
(
slice
(
start
,
start
+
b
[
'coords'
].
shape
[
0
]))
start
+=
b
[
'coords'
].
shape
[
0
]
coords
=
torch
.
cat
(
coords
)
feats
=
torch
.
cat
(
feats
)
pack
[
'x_0'
]
=
SparseTensor
(
coords
=
coords
,
feats
=
feats
,
)
pack
[
'x_0'
].
_shape
=
torch
.
Size
([
len
(
group
),
*
sub_batch
[
0
][
'feats'
].
shape
[
1
:]])
pack
[
'x_0'
].
register_spatial_cache
(
'layout'
,
layout
)
# collate other data
keys
=
[
k
for
k
in
sub_batch
[
0
].
keys
()
if
k
not
in
[
'coords'
,
'feats'
]]
for
k
in
keys
:
if
isinstance
(
sub_batch
[
0
][
k
],
torch
.
Tensor
):
pack
[
k
]
=
torch
.
stack
([
b
[
k
]
for
b
in
sub_batch
])
elif
isinstance
(
sub_batch
[
0
][
k
],
list
):
pack
[
k
]
=
sum
([
b
[
k
]
for
b
in
sub_batch
],
[])
else
:
pack
[
k
]
=
[
b
[
k
]
for
b
in
sub_batch
]
packs
.
append
(
pack
)
if
split_size
is
None
:
return
packs
[
0
]
return
packs
class
ImageConditionedSLat
(
ImageConditionedMixin
,
SLat
):
"""
Image conditioned structured latent dataset
"""
pass
TRELLIS.2_DCU/trellis2/datasets/structured_latent_shape.py
0 → 100644
View file @
f05e915f
import
os
import
json
from
typing
import
*
import
numpy
as
np
import
torch
from
..
import
models
from
.components
import
ImageConditionedMixin
from
..modules.sparse
import
SparseTensor
from
.structured_latent
import
SLatVisMixin
,
SLat
from
..utils.render_utils
import
get_renderer
,
yaw_pitch_r_fov_to_extrinsics_intrinsics
class
SLatShapeVisMixin
(
SLatVisMixin
):
def
_loading_slat_dec
(
self
):
if
self
.
slat_dec
is
not
None
:
return
if
self
.
slat_dec_path
is
not
None
:
cfg
=
json
.
load
(
open
(
os
.
path
.
join
(
self
.
slat_dec_path
,
'config.json'
),
'r'
))
decoder
=
getattr
(
models
,
cfg
[
'models'
][
'decoder'
][
'name'
])(
**
cfg
[
'models'
][
'decoder'
][
'args'
])
ckpt_path
=
os
.
path
.
join
(
self
.
slat_dec_path
,
'ckpts'
,
f
'decoder_
{
self
.
slat_dec_ckpt
}
.pt'
)
decoder
.
load_state_dict
(
torch
.
load
(
ckpt_path
,
map_location
=
'cpu'
,
weights_only
=
True
))
else
:
decoder
=
models
.
from_pretrained
(
self
.
pretrained_slat_dec
)
decoder
.
set_resolution
(
self
.
resolution
)
self
.
slat_dec
=
decoder
.
cuda
().
eval
()
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
x_0
:
Union
[
SparseTensor
,
dict
]):
x_0
=
x_0
if
isinstance
(
x_0
,
SparseTensor
)
else
x_0
[
'x_0'
]
reps
=
self
.
decode_latent
(
x_0
.
cuda
())
# build camera
yaw
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaw_offset
=
-
16
/
180
*
np
.
pi
yaw
=
[
y
+
yaw_offset
for
y
in
yaw
]
pitch
=
[
20
/
180
*
np
.
pi
for
_
in
range
(
4
)]
exts
,
ints
=
yaw_pitch_r_fov_to_extrinsics_intrinsics
(
yaw
,
pitch
,
2
,
30
)
# render
renderer
=
get_renderer
(
reps
[
0
])
images
=
[]
for
representation
in
reps
:
image
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
res
=
renderer
.
render
(
representation
,
ext
,
intr
)
image
[:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
res
[
'normal'
]
images
.
append
(
image
)
images
=
torch
.
stack
(
images
)
return
images
class
SLatShape
(
SLatShapeVisMixin
,
SLat
):
"""
structured latent for shape generation
Args:
roots (str): path to the dataset
resolution (int): resolution of the shape
min_aesthetic_score (float): minimum aesthetic score
max_tokens (int): maximum number of tokens
latent_key (str): key of the latent to be used
normalization (dict): normalization stats
pretrained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def
__init__
(
self
,
roots
:
str
,
*
,
resolution
:
int
,
min_aesthetic_score
:
float
=
5.0
,
max_tokens
:
int
=
32768
,
normalization
:
Optional
[
dict
]
=
None
,
pretrained_slat_dec
:
str
=
'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16'
,
slat_dec_path
:
Optional
[
str
]
=
None
,
slat_dec_ckpt
:
Optional
[
str
]
=
None
,
):
super
().
__init__
(
roots
,
min_aesthetic_score
=
min_aesthetic_score
,
max_tokens
=
max_tokens
,
latent_key
=
'shape_latent'
,
normalization
=
normalization
,
pretrained_slat_dec
=
pretrained_slat_dec
,
slat_dec_path
=
slat_dec_path
,
slat_dec_ckpt
=
slat_dec_ckpt
,
)
self
.
resolution
=
resolution
class
ImageConditionedSLatShape
(
ImageConditionedMixin
,
SLatShape
):
"""
Image conditioned structured latent for shape generation
"""
pass
TRELLIS.2_DCU/trellis2/datasets/structured_latent_svpbr.py
0 → 100644
View file @
f05e915f
import
os
os
.
environ
[
'OPENCV_IO_ENABLE_OPENEXR'
]
=
'1'
import
json
from
typing
import
*
import
numpy
as
np
import
torch
import
cv2
from
..
import
models
from
.components
import
StandardDatasetBase
,
ImageConditionedMixin
from
..modules.sparse
import
SparseTensor
,
sparse_cat
from
..representations
import
MeshWithVoxel
from
..renderers
import
PbrMeshRenderer
,
EnvMap
from
..utils.data_utils
import
load_balanced_group_indices
from
..utils.render_utils
import
yaw_pitch_r_fov_to_extrinsics_intrinsics
class
SLatPbrVisMixin
:
def
__init__
(
self
,
*
args
,
pretrained_pbr_slat_dec
:
str
=
'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16'
,
pbr_slat_dec_path
:
Optional
[
str
]
=
None
,
pbr_slat_dec_ckpt
:
Optional
[
str
]
=
None
,
pretrained_shape_slat_dec
:
str
=
'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16'
,
shape_slat_dec_path
:
Optional
[
str
]
=
None
,
shape_slat_dec_ckpt
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
pbr_slat_dec
=
None
self
.
pretrained_pbr_slat_dec
=
pretrained_pbr_slat_dec
self
.
pbr_slat_dec_path
=
pbr_slat_dec_path
self
.
pbr_slat_dec_ckpt
=
pbr_slat_dec_ckpt
self
.
shape_slat_dec
=
None
self
.
pretrained_shape_slat_dec
=
pretrained_shape_slat_dec
self
.
shape_slat_dec_path
=
shape_slat_dec_path
self
.
shape_slat_dec_ckpt
=
shape_slat_dec_ckpt
def
_loading_slat_dec
(
self
):
if
self
.
pbr_slat_dec
is
not
None
and
self
.
shape_slat_dec
is
not
None
:
return
if
self
.
pbr_slat_dec_path
is
not
None
:
cfg
=
json
.
load
(
open
(
os
.
path
.
join
(
self
.
pbr_slat_dec_path
,
'config.json'
),
'r'
))
decoder
=
getattr
(
models
,
cfg
[
'models'
][
'decoder'
][
'name'
])(
**
cfg
[
'models'
][
'decoder'
][
'args'
])
ckpt_path
=
os
.
path
.
join
(
self
.
pbr_slat_dec_path
,
'ckpts'
,
f
'decoder_
{
self
.
pbr_slat_dec_ckpt
}
.pt'
)
decoder
.
load_state_dict
(
torch
.
load
(
ckpt_path
,
map_location
=
'cpu'
,
weights_only
=
True
))
else
:
decoder
=
models
.
from_pretrained
(
self
.
pretrained_pbr_slat_dec
)
self
.
pbr_slat_dec
=
decoder
.
cuda
().
eval
()
if
self
.
shape_slat_dec_path
is
not
None
:
cfg
=
json
.
load
(
open
(
os
.
path
.
join
(
self
.
shape_slat_dec_path
,
'config.json'
),
'r'
))
decoder
=
getattr
(
models
,
cfg
[
'models'
][
'decoder'
][
'name'
])(
**
cfg
[
'models'
][
'decoder'
][
'args'
])
ckpt_path
=
os
.
path
.
join
(
self
.
shape_slat_dec_path
,
'ckpts'
,
f
'decoder_
{
self
.
shape_slat_dec_ckpt
}
.pt'
)
decoder
.
load_state_dict
(
torch
.
load
(
ckpt_path
,
map_location
=
'cpu'
,
weights_only
=
True
))
else
:
decoder
=
models
.
from_pretrained
(
self
.
pretrained_shape_slat_dec
)
decoder
.
set_resolution
(
self
.
resolution
)
self
.
shape_slat_dec
=
decoder
.
cuda
().
eval
()
def
_delete_slat_dec
(
self
):
del
self
.
pbr_slat_dec
self
.
pbr_slat_dec
=
None
del
self
.
shape_slat_dec
self
.
shape_slat_dec
=
None
@
torch
.
no_grad
()
def
decode_latent
(
self
,
z
,
shape_z
,
batch_size
=
4
):
self
.
_loading_slat_dec
()
reps
=
[]
if
self
.
shape_slat_normalization
is
not
None
:
shape_z
=
shape_z
*
self
.
shape_slat_std
.
to
(
z
.
device
)
+
self
.
shape_slat_mean
.
to
(
z
.
device
)
if
self
.
pbr_slat_normalization
is
not
None
:
z
=
z
*
self
.
pbr_slat_std
.
to
(
z
.
device
)
+
self
.
pbr_slat_mean
.
to
(
z
.
device
)
for
i
in
range
(
0
,
z
.
shape
[
0
],
batch_size
):
mesh
,
subs
=
self
.
shape_slat_dec
(
shape_z
[
i
:
i
+
batch_size
],
return_subs
=
True
)
vox
=
self
.
pbr_slat_dec
(
z
[
i
:
i
+
batch_size
],
guide_subs
=
subs
)
*
0.5
+
0.5
reps
.
extend
([
MeshWithVoxel
(
m
.
vertices
,
m
.
faces
,
origin
=
[
-
0.5
,
-
0.5
,
-
0.5
],
voxel_size
=
1
/
self
.
resolution
,
coords
=
v
.
coords
[:,
1
:],
attrs
=
v
.
feats
,
voxel_shape
=
torch
.
Size
([
*
v
.
shape
,
*
v
.
spatial_shape
]),
layout
=
self
.
layout
,
)
for
m
,
v
in
zip
(
mesh
,
vox
)
])
self
.
_delete_slat_dec
()
return
reps
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
sample
:
dict
):
shape_z
=
sample
[
'concat_cond'
].
cuda
()
z
=
sample
[
'x_0'
].
cuda
()
reps
=
self
.
decode_latent
(
z
,
shape_z
)
# build camera
yaw
=
[
0
,
np
.
pi
/
2
,
np
.
pi
,
3
*
np
.
pi
/
2
]
yaw_offset
=
-
16
/
180
*
np
.
pi
yaw
=
[
y
+
yaw_offset
for
y
in
yaw
]
pitch
=
[
20
/
180
*
np
.
pi
for
_
in
range
(
4
)]
exts
,
ints
=
yaw_pitch_r_fov_to_extrinsics_intrinsics
(
yaw
,
pitch
,
2
,
30
)
# render
renderer
=
PbrMeshRenderer
()
renderer
.
rendering_options
.
resolution
=
512
renderer
.
rendering_options
.
near
=
1
renderer
.
rendering_options
.
far
=
100
renderer
.
rendering_options
.
ssaa
=
2
renderer
.
rendering_options
.
peel_layers
=
8
envmap
=
EnvMap
(
torch
.
tensor
(
cv2
.
cvtColor
(
cv2
.
imread
(
'assets/hdri/forest.exr'
,
cv2
.
IMREAD_UNCHANGED
),
cv2
.
COLOR_BGR2RGB
),
dtype
=
torch
.
float32
,
device
=
'cuda'
))
images
=
{}
for
representation
in
reps
:
image
=
{}
tile
=
[
2
,
2
]
for
j
,
(
ext
,
intr
)
in
enumerate
(
zip
(
exts
,
ints
)):
res
=
renderer
.
render
(
representation
,
ext
,
intr
,
envmap
=
envmap
)
for
k
,
v
in
res
.
items
():
if
k
not
in
images
:
images
[
k
]
=
[]
if
k
not
in
image
:
image
[
k
]
=
torch
.
zeros
(
3
,
1024
,
1024
).
cuda
()
image
[
k
][:,
512
*
(
j
//
tile
[
1
]):
512
*
(
j
//
tile
[
1
]
+
1
),
512
*
(
j
%
tile
[
1
]):
512
*
(
j
%
tile
[
1
]
+
1
)]
=
v
for
k
in
images
.
keys
():
images
[
k
].
append
(
image
[
k
])
for
k
in
images
.
keys
():
images
[
k
]
=
torch
.
stack
(
images
[
k
],
dim
=
0
)
return
images
class
SLatPbr
(
SLatPbrVisMixin
,
StandardDatasetBase
):
"""
structured latent for sparse voxel pbr dataset
Args:
roots (str): path to the dataset
latent_key (str): key of the latent to be used
min_aesthetic_score (float): minimum aesthetic score
normalization (dict): normalization stats
resolution (int): resolution of decoded sparse voxel
attrs (list): attributes to be decoded
pretained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def
__init__
(
self
,
roots
:
str
,
*
,
resolution
:
int
,
min_aesthetic_score
:
float
=
5.0
,
max_tokens
:
int
=
32768
,
full_pbr
:
bool
=
False
,
pbr_slat_normalization
:
Optional
[
dict
]
=
None
,
shape_slat_normalization
:
Optional
[
dict
]
=
None
,
attrs
:
list
[
str
]
=
[
'base_color'
,
'metallic'
,
'roughness'
,
'emissive'
,
'alpha'
],
pretrained_pbr_slat_dec
:
str
=
'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16'
,
pbr_slat_dec_path
:
Optional
[
str
]
=
None
,
pbr_slat_dec_ckpt
:
Optional
[
str
]
=
None
,
pretrained_shape_slat_dec
:
str
=
'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16'
,
shape_slat_dec_path
:
Optional
[
str
]
=
None
,
shape_slat_dec_ckpt
:
Optional
[
str
]
=
None
,
**
kwargs
):
self
.
resolution
=
resolution
self
.
pbr_slat_normalization
=
pbr_slat_normalization
self
.
shape_slat_normalization
=
shape_slat_normalization
self
.
min_aesthetic_score
=
min_aesthetic_score
self
.
max_tokens
=
max_tokens
self
.
full_pbr
=
full_pbr
self
.
value_range
=
(
0
,
1
)
super
().
__init__
(
roots
,
pretrained_pbr_slat_dec
=
pretrained_pbr_slat_dec
,
pbr_slat_dec_path
=
pbr_slat_dec_path
,
pbr_slat_dec_ckpt
=
pbr_slat_dec_ckpt
,
pretrained_shape_slat_dec
=
pretrained_shape_slat_dec
,
shape_slat_dec_path
=
shape_slat_dec_path
,
shape_slat_dec_ckpt
=
shape_slat_dec_ckpt
,
**
kwargs
)
self
.
loads
=
[
self
.
metadata
.
loc
[
sha256
,
'pbr_latent_tokens'
]
for
_
,
sha256
in
self
.
instances
]
if
self
.
pbr_slat_normalization
is
not
None
:
self
.
pbr_slat_mean
=
torch
.
tensor
(
self
.
pbr_slat_normalization
[
'mean'
]).
reshape
(
1
,
-
1
)
self
.
pbr_slat_std
=
torch
.
tensor
(
self
.
pbr_slat_normalization
[
'std'
]).
reshape
(
1
,
-
1
)
if
self
.
shape_slat_normalization
is
not
None
:
self
.
shape_slat_mean
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'mean'
]).
reshape
(
1
,
-
1
)
self
.
shape_slat_std
=
torch
.
tensor
(
self
.
shape_slat_normalization
[
'std'
]).
reshape
(
1
,
-
1
)
self
.
attrs
=
attrs
self
.
channels
=
{
'base_color'
:
3
,
'metallic'
:
1
,
'roughness'
:
1
,
'emissive'
:
3
,
'alpha'
:
1
,
}
self
.
layout
=
{}
start
=
0
for
attr
in
attrs
:
self
.
layout
[
attr
]
=
slice
(
start
,
start
+
self
.
channels
[
attr
])
start
+=
self
.
channels
[
attr
]
def
filter_metadata
(
self
,
metadata
):
stats
=
{}
metadata
=
metadata
[
metadata
[
'pbr_latent_encoded'
]
==
True
]
stats
[
'With PBR latent'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'shape_latent_encoded'
]
==
True
]
stats
[
'With shape latent'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'aesthetic_score'
]
>=
self
.
min_aesthetic_score
]
stats
[
f
'Aesthetic score >=
{
self
.
min_aesthetic_score
}
'
]
=
len
(
metadata
)
metadata
=
metadata
[
metadata
[
'pbr_latent_tokens'
]
<=
self
.
max_tokens
]
stats
[
f
'Num tokens <=
{
self
.
max_tokens
}
'
]
=
len
(
metadata
)
if
self
.
full_pbr
:
metadata
=
metadata
[
metadata
[
'num_basecolor_tex'
]
>
0
]
metadata
=
metadata
[
metadata
[
'num_metallic_tex'
]
>
0
]
metadata
=
metadata
[
metadata
[
'num_roughness_tex'
]
>
0
]
stats
[
'Full PBR'
]
=
len
(
metadata
)
return
metadata
,
stats
def
get_instance
(
self
,
root
,
instance
):
# PBR latent
data
=
np
.
load
(
os
.
path
.
join
(
root
[
'pbr_latent'
],
f
'
{
instance
}
.npz'
))
coords
=
torch
.
tensor
(
data
[
'coords'
]).
int
()
coords
=
torch
.
cat
([
torch
.
zeros_like
(
coords
)[:,
:
1
],
coords
],
dim
=
1
)
feats
=
torch
.
tensor
(
data
[
'feats'
]).
float
()
if
self
.
pbr_slat_normalization
is
not
None
:
feats
=
(
feats
-
self
.
pbr_slat_mean
)
/
self
.
pbr_slat_std
pbr_z
=
SparseTensor
(
feats
,
coords
)
# Shape latent
data
=
np
.
load
(
os
.
path
.
join
(
root
[
'shape_latent'
],
f
'
{
instance
}
.npz'
))
coords
=
torch
.
tensor
(
data
[
'coords'
]).
int
()
coords
=
torch
.
cat
([
torch
.
zeros_like
(
coords
)[:,
:
1
],
coords
],
dim
=
1
)
feats
=
torch
.
tensor
(
data
[
'feats'
]).
float
()
if
self
.
shape_slat_normalization
is
not
None
:
feats
=
(
feats
-
self
.
shape_slat_mean
)
/
self
.
shape_slat_std
shape_z
=
SparseTensor
(
feats
,
coords
)
assert
torch
.
equal
(
shape_z
.
coords
,
pbr_z
.
coords
),
\
f
"Shape latent and PBR latent have different coordinates:
{
shape_z
.
coords
.
shape
}
vs
{
pbr_z
.
coords
.
shape
}
"
return
{
'x_0'
:
pbr_z
,
'concat_cond'
:
shape_z
,
}
@
staticmethod
def
collate_fn
(
batch
,
split_size
=
None
):
if
split_size
is
None
:
group_idx
=
[
list
(
range
(
len
(
batch
)))]
else
:
group_idx
=
load_balanced_group_indices
([
b
[
'x_0'
].
feats
.
shape
[
0
]
for
b
in
batch
],
split_size
)
packs
=
[]
for
group
in
group_idx
:
sub_batch
=
[
batch
[
i
]
for
i
in
group
]
pack
=
{}
keys
=
[
k
for
k
in
sub_batch
[
0
].
keys
()]
for
k
in
keys
:
if
isinstance
(
sub_batch
[
0
][
k
],
torch
.
Tensor
):
pack
[
k
]
=
torch
.
stack
([
b
[
k
]
for
b
in
sub_batch
])
elif
isinstance
(
sub_batch
[
0
][
k
],
SparseTensor
):
pack
[
k
]
=
sparse_cat
([
b
[
k
]
for
b
in
sub_batch
],
dim
=
0
)
elif
isinstance
(
sub_batch
[
0
][
k
],
list
):
pack
[
k
]
=
sum
([
b
[
k
]
for
b
in
sub_batch
],
[])
else
:
pack
[
k
]
=
[
b
[
k
]
for
b
in
sub_batch
]
packs
.
append
(
pack
)
if
split_size
is
None
:
return
packs
[
0
]
return
packs
class
ImageConditionedSLatPbr
(
ImageConditionedMixin
,
SLatPbr
):
"""
Image conditioned structured latent dataset
"""
pass
TRELLIS.2_DCU/trellis2/models/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
# Sparse Structure
'SparseStructureEncoder'
:
'sparse_structure_vae'
,
'SparseStructureDecoder'
:
'sparse_structure_vae'
,
'SparseStructureFlowModel'
:
'sparse_structure_flow'
,
# SLat Generation
'SLatFlowModel'
:
'structured_latent_flow'
,
'ElasticSLatFlowModel'
:
'structured_latent_flow'
,
# SC-VAEs
'SparseUnetVaeEncoder'
:
'sc_vaes.sparse_unet_vae'
,
'SparseUnetVaeDecoder'
:
'sc_vaes.sparse_unet_vae'
,
'FlexiDualGridVaeEncoder'
:
'sc_vaes.fdg_vae'
,
'FlexiDualGridVaeDecoder'
:
'sc_vaes.fdg_vae'
}
__submodules
=
[]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
def
from_pretrained
(
path
:
str
,
**
kwargs
):
"""
Load a model from a pretrained checkpoint.
Args:
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
**kwargs: Additional arguments for the model constructor.
"""
import
os
import
json
from
safetensors.torch
import
load_file
is_local
=
os
.
path
.
exists
(
f
"
{
path
}
.json"
)
and
os
.
path
.
exists
(
f
"
{
path
}
.safetensors"
)
if
is_local
:
config_file
=
f
"
{
path
}
.json"
model_file
=
f
"
{
path
}
.safetensors"
else
:
from
huggingface_hub
import
hf_hub_download
path_parts
=
path
.
split
(
'/'
)
repo_id
=
f
'
{
path_parts
[
0
]
}
/
{
path_parts
[
1
]
}
'
model_name
=
'/'
.
join
(
path_parts
[
2
:])
config_file
=
hf_hub_download
(
repo_id
,
f
"
{
model_name
}
.json"
)
model_file
=
hf_hub_download
(
repo_id
,
f
"
{
model_name
}
.safetensors"
)
with
open
(
config_file
,
'r'
)
as
f
:
config
=
json
.
load
(
f
)
model
=
__getattr__
(
config
[
'name'
])(
**
config
[
'args'
],
**
kwargs
)
model
.
load_state_dict
(
load_file
(
model_file
),
strict
=
False
)
return
model
# For Pylance
if
__name__
==
'__main__'
:
from
.sparse_structure_vae
import
SparseStructureEncoder
,
SparseStructureDecoder
from
.sparse_structure_flow
import
SparseStructureFlowModel
from
.structured_latent_flow
import
SLatFlowModel
,
ElasticSLatFlowModel
from
.sc_vaes.sparse_unet_vae
import
SparseUnetVaeEncoder
,
SparseUnetVaeDecoder
from
.sc_vaes.fdg_vae
import
FlexiDualGridVaeEncoder
,
FlexiDualGridVaeDecoder
TRELLIS.2_DCU/trellis2/models/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_elastic_mixin.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_structure_flow.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/__pycache__/sparse_structure_vae.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/__pycache__/structured_latent_flow.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/sc_vaes/__pycache__/sparse_unet_vae.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/models/sc_vaes/fdg_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
...modules
import
sparse
as
sp
from
...utils.pipeline_logger
import
get_logger
from
.sparse_unet_vae
import
(
SparseResBlock3d
,
SparseConvNeXtBlock3d
,
SparseResBlockDownsample3d
,
SparseResBlockUpsample3d
,
SparseResBlockS2C3d
,
SparseResBlockC2S3d
,
)
from
.sparse_unet_vae
import
(
SparseUnetVaeEncoder
,
SparseUnetVaeDecoder
,
)
from
...representations
import
Mesh
from
o_voxel.convert
import
flexible_dual_grid_to_mesh
class
FlexiDualGridVaeEncoder
(
SparseUnetVaeEncoder
):
def
__init__
(
self
,
model_channels
:
List
[
int
],
latent_channels
:
int
,
num_blocks
:
List
[
int
],
block_type
:
List
[
str
],
down_block_type
:
List
[
str
],
block_args
:
List
[
Dict
[
str
,
Any
]],
use_fp16
:
bool
=
False
,
):
super
().
__init__
(
6
,
model_channels
,
latent_channels
,
num_blocks
,
block_type
,
down_block_type
,
block_args
,
use_fp16
,
)
def
forward
(
self
,
vertices
:
sp
.
SparseTensor
,
intersected
:
sp
.
SparseTensor
,
sample_posterior
=
False
,
return_raw
=
False
):
x
=
vertices
.
replace
(
torch
.
cat
([
vertices
.
feats
-
0.5
,
intersected
.
feats
.
float
()
-
0.5
,
],
dim
=
1
))
return
super
().
forward
(
x
,
sample_posterior
,
return_raw
)
class
FlexiDualGridVaeDecoder
(
SparseUnetVaeDecoder
):
def
__init__
(
self
,
resolution
:
int
,
model_channels
:
List
[
int
],
latent_channels
:
int
,
num_blocks
:
List
[
int
],
block_type
:
List
[
str
],
up_block_type
:
List
[
str
],
block_args
:
List
[
Dict
[
str
,
Any
]],
voxel_margin
:
float
=
0.5
,
use_fp16
:
bool
=
False
,
):
self
.
resolution
=
resolution
self
.
voxel_margin
=
voxel_margin
super
().
__init__
(
7
,
model_channels
,
latent_channels
,
num_blocks
,
block_type
,
up_block_type
,
block_args
,
use_fp16
,
)
def
set_resolution
(
self
,
resolution
:
int
)
->
None
:
self
.
resolution
=
resolution
def
forward
(
self
,
x
:
sp
.
SparseTensor
,
gt_intersected
:
sp
.
SparseTensor
=
None
,
**
kwargs
):
decoded
=
super
().
forward
(
x
,
**
kwargs
)
if
self
.
training
:
h
,
subs_gt
,
subs
=
decoded
vertices
=
h
.
replace
((
1
+
2
*
self
.
voxel_margin
)
*
F
.
sigmoid
(
h
.
feats
[...,
0
:
3
])
-
self
.
voxel_margin
)
intersected_logits
=
h
.
replace
(
h
.
feats
[...,
3
:
6
])
quad_lerp
=
h
.
replace
(
F
.
softplus
(
h
.
feats
[...,
6
:
7
]))
mesh
=
[
Mesh
(
*
flexible_dual_grid_to_mesh
(
v
.
coords
[:,
1
:],
v
.
feats
,
i
.
feats
,
q
.
feats
,
aabb
=
[[
-
0.5
,
-
0.5
,
-
0.5
],
[
0.5
,
0.5
,
0.5
]],
grid_size
=
self
.
resolution
,
train
=
True
))
for
v
,
i
,
q
in
zip
(
vertices
,
gt_intersected
,
quad_lerp
)]
return
mesh
,
vertices
,
intersected_logits
,
subs_gt
,
subs
else
:
out_list
=
list
(
decoded
)
if
isinstance
(
decoded
,
tuple
)
else
[
decoded
]
h
=
out_list
[
0
]
get_logger
().
debug
(
f
"post-forward dtype=
{
h
.
feats
.
dtype
}
has_nan=
{
torch
.
isnan
(
h
.
feats
).
any
()
}
"
)
get_logger
().
debug
(
f
"DEBUG 1: VAE output h.feats has NaNs:
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
"
)
vertices
=
h
.
replace
((
1
+
2
*
self
.
voxel_margin
)
*
F
.
sigmoid
(
h
.
feats
[...,
0
:
3
])
-
self
.
voxel_margin
)
intersected
=
h
.
replace
(
h
.
feats
[...,
3
:
6
]
>
0
)
get_logger
().
debug
(
f
"DEBUG INTERSECTED: total=
{
intersected
.
feats
.
shape
[
0
]
}
, "
f
"true=
{
intersected
.
feats
.
any
(
dim
=-
1
).
sum
().
item
()
}
, "
f
"ratio=
{
intersected
.
feats
.
any
(
dim
=-
1
).
float
().
mean
():.
3
f
}
"
)
quad_lerp
=
h
.
replace
(
F
.
softplus
(
h
.
feats
[...,
6
:
7
]))
mesh
=
[
Mesh
(
*
flexible_dual_grid_to_mesh
(
v
.
coords
[:,
1
:],
v
.
feats
,
i
.
feats
,
q
.
feats
,
aabb
=
[[
-
0.5
,
-
0.5
,
-
0.5
],
[
0.5
,
0.5
,
0.5
]],
grid_size
=
self
.
resolution
,
train
=
False
))
for
v
,
i
,
q
in
zip
(
vertices
,
intersected
,
quad_lerp
)]
get_logger
().
debug
(
f
"DEBUG 2: o_voxel mesh[0] vertices has NaNs:
{
torch
.
isnan
(
mesh
[
0
].
vertices
).
any
().
item
()
}
"
)
out_list
[
0
]
=
mesh
return
out_list
[
0
]
if
len
(
out_list
)
==
1
else
tuple
(
out_list
)
TRELLIS.2_DCU/trellis2/models/sc_vaes/sparse_unet_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
from
...modules.utils
import
convert_module_to_f16
,
convert_module_to_bf16
,
convert_module_to_f32
,
zero_module
from
...modules
import
sparse
as
sp
from
...modules.sparse.linear
import
rocm_safe_linear
,
ROCM_SAFE_CHUNK
from
...modules.norm
import
LayerNorm32
from
...utils.pipeline_logger
import
get_logger
class
SparseResBlock3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
out_channels
:
Optional
[
int
]
=
None
,
downsample
:
bool
=
False
,
upsample
:
bool
=
False
,
resample_mode
:
Literal
[
'nearest'
,
'spatial2channel'
]
=
'nearest'
,
use_checkpoint
:
bool
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
downsample
=
downsample
self
.
upsample
=
upsample
self
.
resample_mode
=
resample_mode
self
.
use_checkpoint
=
use_checkpoint
assert
not
(
downsample
and
upsample
),
"Cannot downsample and upsample at the same time"
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
self
.
out_channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
if
resample_mode
==
'nearest'
:
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
,
3
)
elif
resample_mode
==
'spatial2channel'
and
not
self
.
downsample
:
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
*
8
,
3
)
elif
resample_mode
==
'spatial2channel'
and
self
.
downsample
:
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
//
8
,
3
)
self
.
conv2
=
zero_module
(
sp
.
SparseConv3d
(
self
.
out_channels
,
self
.
out_channels
,
3
))
if
resample_mode
==
'nearest'
:
self
.
skip_connection
=
sp
.
SparseLinear
(
channels
,
self
.
out_channels
)
if
channels
!=
self
.
out_channels
else
nn
.
Identity
()
elif
resample_mode
==
'spatial2channel'
and
self
.
downsample
:
self
.
skip_connection
=
lambda
x
:
x
.
replace
(
x
.
feats
.
reshape
(
x
.
feats
.
shape
[
0
],
out_channels
,
channels
*
8
//
out_channels
).
mean
(
dim
=-
1
))
elif
resample_mode
==
'spatial2channel'
and
not
self
.
downsample
:
self
.
skip_connection
=
lambda
x
:
x
.
replace
(
x
.
feats
.
repeat_interleave
(
out_channels
//
(
channels
//
8
),
dim
=
1
))
self
.
updown
=
None
if
self
.
downsample
:
if
resample_mode
==
'nearest'
:
self
.
updown
=
sp
.
SparseDownsample
(
2
)
elif
resample_mode
==
'spatial2channel'
:
self
.
updown
=
sp
.
SparseSpatial2Channel
(
2
)
elif
self
.
upsample
:
self
.
to_subdiv
=
sp
.
SparseLinear
(
channels
,
8
)
if
resample_mode
==
'nearest'
:
self
.
updown
=
sp
.
SparseUpsample
(
2
)
elif
resample_mode
==
'spatial2channel'
:
self
.
updown
=
sp
.
SparseChannel2Spatial
(
2
)
def
_updown
(
self
,
x
:
sp
.
SparseTensor
,
subdiv
:
sp
.
SparseTensor
=
None
)
->
sp
.
SparseTensor
:
if
self
.
downsample
:
x
=
self
.
updown
(
x
)
elif
self
.
upsample
:
x
=
self
.
updown
(
x
,
subdiv
.
replace
(
subdiv
.
feats
>
0
))
return
x
def
_forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
subdiv
=
None
if
self
.
upsample
:
subdiv
=
self
.
to_subdiv
(
x
)
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
if
self
.
resample_mode
==
'spatial2channel'
:
h
=
self
.
conv1
(
h
)
h
=
self
.
_updown
(
h
,
subdiv
)
x
=
self
.
_updown
(
x
,
subdiv
)
if
self
.
resample_mode
==
'nearest'
:
h
=
self
.
conv1
(
h
)
h
=
h
.
replace
(
self
.
norm2
(
h
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv2
(
h
)
h
=
h
+
self
.
skip_connection
(
x
)
if
self
.
upsample
:
return
h
,
subdiv
return
h
def
forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseResBlockDownsample3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
out_channels
:
Optional
[
int
]
=
None
,
use_checkpoint
:
bool
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
self
.
out_channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
,
3
)
self
.
conv2
=
zero_module
(
sp
.
SparseConv3d
(
self
.
out_channels
,
self
.
out_channels
,
3
))
self
.
skip_connection
=
sp
.
SparseLinear
(
channels
,
self
.
out_channels
)
if
channels
!=
self
.
out_channels
else
nn
.
Identity
()
self
.
updown
=
sp
.
SparseDownsample
(
2
)
def
_forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
updown
(
h
)
x
=
self
.
updown
(
x
)
h
=
self
.
conv1
(
h
)
h
=
h
.
replace
(
self
.
norm2
(
h
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv2
(
h
)
h
=
h
+
self
.
skip_connection
(
x
)
return
h
def
forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseResBlockUpsample3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
out_channels
:
Optional
[
int
]
=
None
,
use_checkpoint
:
bool
=
False
,
pred_subdiv
:
bool
=
True
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_checkpoint
=
use_checkpoint
self
.
pred_subdiv
=
pred_subdiv
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
self
.
out_channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
,
3
)
self
.
conv2
=
zero_module
(
sp
.
SparseConv3d
(
self
.
out_channels
,
self
.
out_channels
,
3
))
self
.
skip_connection
=
sp
.
SparseLinear
(
channels
,
self
.
out_channels
)
if
channels
!=
self
.
out_channels
else
nn
.
Identity
()
if
self
.
pred_subdiv
:
self
.
to_subdiv
=
sp
.
SparseLinear
(
channels
,
8
)
self
.
updown
=
sp
.
SparseUpsample
(
2
)
def
_forward
(
self
,
x
:
sp
.
SparseTensor
,
subdiv
:
sp
.
SparseTensor
=
None
)
->
sp
.
SparseTensor
:
if
self
.
pred_subdiv
:
subdiv
=
self
.
to_subdiv
(
x
)
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
subdiv_binarized
=
subdiv
.
replace
(
subdiv
.
feats
>
0
)
if
subdiv
is
not
None
else
None
h
=
self
.
updown
(
h
,
subdiv_binarized
)
x
=
self
.
updown
(
x
,
subdiv_binarized
)
h
=
self
.
conv1
(
h
)
h
=
h
.
replace
(
self
.
norm2
(
h
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv2
(
h
)
h
=
h
+
self
.
skip_connection
(
x
)
if
self
.
pred_subdiv
:
return
h
,
subdiv
else
:
return
h
def
forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseResBlockS2C3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
out_channels
:
Optional
[
int
]
=
None
,
use_checkpoint
:
bool
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
self
.
out_channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
//
8
,
3
)
self
.
conv2
=
zero_module
(
sp
.
SparseConv3d
(
self
.
out_channels
,
self
.
out_channels
,
3
))
self
.
skip_connection
=
lambda
x
:
x
.
replace
(
x
.
feats
.
reshape
(
x
.
feats
.
shape
[
0
],
out_channels
,
channels
*
8
//
out_channels
).
mean
(
dim
=-
1
))
self
.
updown
=
sp
.
SparseSpatial2Channel
(
2
)
def
_forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv1
(
h
)
h
=
self
.
updown
(
h
)
x
=
self
.
updown
(
x
)
h
=
h
.
replace
(
self
.
norm2
(
h
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv2
(
h
)
h
=
h
+
self
.
skip_connection
(
x
)
return
h
def
forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseResBlockC2S3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
out_channels
:
Optional
[
int
]
=
None
,
use_checkpoint
:
bool
=
False
,
pred_subdiv
:
bool
=
True
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_checkpoint
=
use_checkpoint
self
.
pred_subdiv
=
pred_subdiv
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
self
.
out_channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
conv1
=
sp
.
SparseConv3d
(
channels
,
self
.
out_channels
*
8
,
3
)
self
.
conv2
=
zero_module
(
sp
.
SparseConv3d
(
self
.
out_channels
,
self
.
out_channels
,
3
))
self
.
skip_connection
=
lambda
x
:
x
.
replace
(
x
.
feats
.
repeat_interleave
(
out_channels
//
(
channels
//
8
),
dim
=
1
))
if
pred_subdiv
:
self
.
to_subdiv
=
sp
.
SparseLinear
(
channels
,
8
)
self
.
updown
=
sp
.
SparseChannel2Spatial
(
2
)
def
_forward
(
self
,
x
:
sp
.
SparseTensor
,
subdiv
:
sp
.
SparseTensor
=
None
)
->
sp
.
SparseTensor
:
if
self
.
pred_subdiv
:
subdiv
=
self
.
to_subdiv
(
x
)
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv1
(
h
)
# ROCm: cast to fp32 before threshold - bf16 trained weights produce shifted logits\n
subdiv_binarized
=
subdiv
.
replace
(
subdiv
.
feats
.
float
()
>
0
)
if
subdiv
is
not
None
else
None
h
=
self
.
updown
(
h
,
subdiv_binarized
)
x
=
self
.
updown
(
x
,
subdiv_binarized
)
h
=
h
.
replace
(
self
.
norm2
(
h
.
feats
))
h
=
h
.
replace
(
F
.
silu
(
h
.
feats
))
h
=
self
.
conv2
(
h
)
h
=
h
+
self
.
skip_connection
(
x
)
if
self
.
pred_subdiv
:
return
h
,
subdiv
else
:
return
h
def
forward
(
self
,
x
:
sp
.
SparseTensor
,
subdiv
:
sp
.
SparseTensor
=
None
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
subdiv
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
subdiv
)
class
SparseConvNeXtBlock3d
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
mlp_ratio
:
float
=
4.0
,
use_checkpoint
:
bool
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
conv
=
sp
.
SparseConv3d
(
channels
,
channels
,
3
)
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
channels
,
int
(
channels
*
mlp_ratio
)),
nn
.
SiLU
(),
zero_module
(
nn
.
Linear
(
int
(
channels
*
mlp_ratio
),
channels
)),
)
def
_forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
h
=
self
.
conv
(
x
)
h
=
h
.
replace
(
self
.
norm
(
h
.
feats
))
# ROCm GFX1201 bug workaround: chunk MLP (two nn.Linear layers inside) for large N
# The MLP is row-independent so chunking is exact, not an approximation
feats
=
h
.
feats
N
=
feats
.
shape
[
0
]
if
N
<=
ROCM_SAFE_CHUNK
:
h
=
h
.
replace
(
self
.
mlp
(
feats
))
else
:
out
=
torch
.
empty_like
(
feats
)
for
s
in
range
(
0
,
N
,
ROCM_SAFE_CHUNK
):
e
=
min
(
s
+
ROCM_SAFE_CHUNK
,
N
)
out
[
s
:
e
]
=
self
.
mlp
(
feats
[
s
:
e
])
h
=
h
.
replace
(
out
)
return
h
+
x
def
forward
(
self
,
x
:
sp
.
SparseTensor
)
->
sp
.
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseUnetVaeEncoder
(
nn
.
Module
):
"""
Sparse Swin Transformer Unet VAE model.
"""
def
__init__
(
self
,
in_channels
:
int
,
model_channels
:
List
[
int
],
latent_channels
:
int
,
num_blocks
:
List
[
int
],
block_type
:
List
[
str
],
down_block_type
:
List
[
str
],
block_args
:
List
[
Dict
[
str
,
Any
]],
use_fp16
:
bool
=
False
,
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
num_blocks
=
num_blocks
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
input_layer
=
sp
.
SparseLinear
(
in_channels
,
model_channels
[
0
])
self
.
to_latent
=
sp
.
SparseLinear
(
model_channels
[
-
1
],
2
*
latent_channels
)
self
.
blocks
=
nn
.
ModuleList
([])
for
i
in
range
(
len
(
num_blocks
)):
self
.
blocks
.
append
(
nn
.
ModuleList
([]))
for
j
in
range
(
num_blocks
[
i
]):
self
.
blocks
[
-
1
].
append
(
globals
()[
block_type
[
i
]](
model_channels
[
i
],
**
block_args
[
i
],
)
)
if
i
<
len
(
num_blocks
)
-
1
:
self
.
blocks
[
-
1
].
append
(
globals
()[
down_block_type
[
i
]](
model_channels
[
i
],
model_channels
[
i
+
1
],
**
block_args
[
i
],
)
)
self
.
initialize_weights
()
if
use_fp16
:
self
.
convert_to_fp16
()
@
property
def
device
(
self
)
->
torch
.
device
:
"""
Return the device of the model.
"""
return
next
(
self
.
parameters
()).
device
def
convert_to_fp16
(
self
)
->
None
:
"""
Convert the torso of the model to float16 (actually bfloat16 for ROCm stability).
"""
self
.
blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
)
->
None
:
"""
Convert the torso of the model to float32.
"""
self
.
blocks
.
apply
(
convert_module_to_f32
)
def
initialize_weights
(
self
)
->
None
:
# Initialize transformer layers:
def
_basic_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
xavier_uniform_
(
module
.
weight
)
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
self
.
apply
(
_basic_init
)
def
forward
(
self
,
x
:
sp
.
SparseTensor
,
sample_posterior
=
False
,
return_raw
=
False
):
h
=
self
.
input_layer
(
x
)
h
=
h
.
type
(
self
.
dtype
)
for
i
,
res
in
enumerate
(
self
.
blocks
):
for
j
,
block
in
enumerate
(
res
):
h
=
block
(
h
)
h
=
h
.
type
(
x
.
dtype
)
h
=
h
.
replace
(
F
.
layer_norm
(
h
.
feats
,
h
.
feats
.
shape
[
-
1
:]))
h
=
self
.
to_latent
(
h
)
# Sample from the posterior distribution
mean
,
logvar
=
h
.
feats
.
chunk
(
2
,
dim
=-
1
)
if
sample_posterior
:
std
=
torch
.
exp
(
0.5
*
logvar
)
z
=
mean
+
std
*
torch
.
randn_like
(
std
)
else
:
z
=
mean
z
=
h
.
replace
(
z
)
if
return_raw
:
return
z
,
mean
,
logvar
else
:
return
z
class
SparseUnetVaeDecoder
(
nn
.
Module
):
"""
Sparse Swin Transformer Unet VAE model.
"""
def
__init__
(
self
,
out_channels
:
int
,
model_channels
:
List
[
int
],
latent_channels
:
int
,
num_blocks
:
List
[
int
],
block_type
:
List
[
str
],
up_block_type
:
List
[
str
],
block_args
:
List
[
Dict
[
str
,
Any
]],
use_fp16
:
bool
=
False
,
pred_subdiv
:
bool
=
True
,
):
super
().
__init__
()
self
.
out_channels
=
out_channels
self
.
model_channels
=
model_channels
self
.
num_blocks
=
num_blocks
self
.
use_fp16
=
use_fp16
self
.
pred_subdiv
=
pred_subdiv
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
low_vram
=
False
self
.
output_layer
=
sp
.
SparseLinear
(
model_channels
[
-
1
],
out_channels
)
self
.
from_latent
=
sp
.
SparseLinear
(
latent_channels
,
model_channels
[
0
])
self
.
blocks
=
nn
.
ModuleList
([])
for
i
in
range
(
len
(
num_blocks
)):
self
.
blocks
.
append
(
nn
.
ModuleList
([]))
for
j
in
range
(
num_blocks
[
i
]):
self
.
blocks
[
-
1
].
append
(
globals
()[
block_type
[
i
]](
model_channels
[
i
],
**
block_args
[
i
],
)
)
if
i
<
len
(
num_blocks
)
-
1
:
self
.
blocks
[
-
1
].
append
(
globals
()[
up_block_type
[
i
]](
model_channels
[
i
],
model_channels
[
i
+
1
],
pred_subdiv
=
pred_subdiv
,
**
block_args
[
i
],
)
)
self
.
initialize_weights
()
if
use_fp16
:
self
.
convert_to_fp16
()
@
property
def
device
(
self
)
->
torch
.
device
:
"""
Return the device of the model.
"""
return
next
(
self
.
parameters
()).
device
def
convert_to_fp16
(
self
)
->
None
:
"""
Convert the torso of the model to float16 (actually bfloat16 for ROCm stability).
"""
self
.
blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
)
->
None
:
"""
Convert the torso of the model to float32.
"""
self
.
blocks
.
apply
(
convert_module_to_f32
)
def
initialize_weights
(
self
)
->
None
:
# Initialize transformer layers:
def
_basic_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
xavier_uniform_
(
module
.
weight
)
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
self
.
apply
(
_basic_init
)
def
forward
(
self
,
x
:
sp
.
SparseTensor
,
guide_subs
:
Optional
[
List
[
sp
.
SparseTensor
]]
=
None
,
return_subs
:
bool
=
False
)
->
sp
.
SparseTensor
:
assert
guide_subs
is
None
or
self
.
pred_subdiv
==
False
,
"Only decoders with pred_subdiv=False can be used with guide_subs"
assert
return_subs
==
False
or
self
.
pred_subdiv
==
True
,
"Only decoders with pred_subdiv=True can be used with return_subs"
h
=
self
.
from_latent
(
x
)
get_logger
().
debug
(
f
"DECODER from_latent: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
dtype=
{
h
.
feats
.
dtype
}
"
)
h
=
h
.
type
(
self
.
dtype
)
get_logger
().
debug
(
f
"DECODER after dtype cast: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
dtype=
{
h
.
feats
.
dtype
}
"
)
subs_gt
=
[]
subs
=
[]
for
i
,
res
in
enumerate
(
self
.
blocks
):
for
j
,
block
in
enumerate
(
res
):
if
i
<
len
(
self
.
blocks
)
-
1
and
j
==
len
(
res
)
-
1
:
if
self
.
pred_subdiv
:
if
self
.
training
:
subs_gt
.
append
(
h
.
get_spatial_cache
(
'subdivision'
))
h
,
sub
=
block
(
h
)
subs
.
append
(
sub
)
else
:
h
=
block
(
h
,
subdiv
=
guide_subs
[
i
]
if
guide_subs
is
not
None
else
None
)
else
:
h
=
block
(
h
)
if
not
torch
.
isfinite
(
h
.
feats
).
all
():
print
(
f
"FATAL: NaN/Inf at decoder block i=
{
i
}
j=
{
j
}
type=
{
type
(
block
).
__name__
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
import
sys
;
sys
.
exit
(
1
)
h
=
h
.
type
(
x
.
dtype
)
get_logger
().
debug
(
f
"DECODER post-blocks cast: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
dtype=
{
h
.
feats
.
dtype
}
"
)
h
=
h
.
replace
(
F
.
layer_norm
(
h
.
feats
,
h
.
feats
.
shape
[
-
1
:]))
get_logger
().
debug
(
f
"DECODER post-layernorm: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
get_logger
().
debug
(
f
"DECODER output_layer input: shape=
{
h
.
feats
.
shape
}
stride=
{
h
.
feats
.
stride
()
}
contiguous=
{
h
.
feats
.
is_contiguous
()
}
"
)
get_logger
().
debug
(
f
"DECODER output_layer weight: shape=
{
self
.
output_layer
.
weight
.
shape
}
dtype=
{
self
.
output_layer
.
weight
.
dtype
}
"
)
get_logger
().
debug
(
f
"DECODER pre-output_layer: feats shape=
{
h
.
feats
.
shape
}
contiguous=
{
h
.
feats
.
is_contiguous
()
}
weight shape=
{
self
.
output_layer
.
weight
.
shape
}
weight dtype=
{
self
.
output_layer
.
weight
.
dtype
}
"
)
# ROCm workaround: ensure contiguous before F.linear
h
=
h
.
replace
(
h
.
feats
.
contiguous
())
h
=
self
.
output_layer
(
h
)
get_logger
().
debug
(
f
"DECODER post-output_layer: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
dtype=
{
h
.
feats
.
dtype
}
"
)
get_logger
().
debug
(
f
"DEBUG OUTPUT_LAYER: dtype=
{
h
.
feats
.
dtype
}
has_nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
max_abs=
{
h
.
feats
.
abs
().
max
().
item
()
if
h
.
feats
.
numel
()
>
0
else
0
}
"
)
if
self
.
training
and
self
.
pred_subdiv
:
return
h
,
subs_gt
,
subs
else
:
if
return_subs
:
return
h
,
subs
else
:
return
h
# REPLACE WITH:
def
upsample
(
self
,
x
:
sp
.
SparseTensor
,
upsample_times
:
int
)
->
torch
.
Tensor
:
assert
self
.
pred_subdiv
==
True
,
"Only decoders with pred_subdiv=True can be used with upsampling"
h
=
self
.
from_latent
(
x
)
get_logger
().
debug
(
f
"UPSAMPLE from_latent: dtype=
{
h
.
feats
.
dtype
}
nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
h
=
h
.
type
(
self
.
dtype
)
get_logger
().
debug
(
f
"UPSAMPLE after type cast to
{
self
.
dtype
}
: nan=
{
torch
.
isnan
(
h
.
feats
).
any
().
item
()
}
inf=
{
torch
.
isinf
(
h
.
feats
).
any
().
item
()
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
for
i
,
res
in
enumerate
(
self
.
blocks
):
if
i
==
upsample_times
:
return
h
.
coords
for
j
,
block
in
enumerate
(
res
):
if
i
<
len
(
self
.
blocks
)
-
1
and
j
==
len
(
res
)
-
1
:
h
,
sub
=
block
(
h
)
else
:
h
=
block
(
h
)
if
torch
.
isnan
(
h
.
feats
).
any
()
or
torch
.
isinf
(
h
.
feats
).
any
():
print
(
f
"UPSAMPLE NaN/Inf at block i=
{
i
}
j=
{
j
}
type=
{
type
(
block
).
__name__
}
max=
{
h
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
break
else
:
continue
break
def
dump_debug
(
self
,
tag
:
str
,
tensor
)
->
None
:
import
os
os
.
makedirs
(
'/tmp/trellis_debug'
,
exist_ok
=
True
)
path
=
f
'/tmp/trellis_debug/
{
tag
}
.pt'
torch
.
save
({
'feats'
:
tensor
.
feats
.
float
().
cpu
(),
'coords'
:
tensor
.
coords
.
cpu
()},
path
)
print
(
f
"DUMPED
{
tag
}
: feats dtype=
{
tensor
.
feats
.
dtype
}
shape=
{
tensor
.
feats
.
shape
}
"
f
"has_nan=
{
torch
.
isnan
(
tensor
.
feats
).
any
().
item
()
}
"
f
"has_inf=
{
torch
.
isinf
(
tensor
.
feats
).
any
().
item
()
}
"
f
"max=
{
tensor
.
feats
.
float
().
abs
().
max
().
item
():.
4
f
}
"
)
TRELLIS.2_DCU/trellis2/models/sparse_elastic_mixin.py
0 → 100644
View file @
f05e915f
from
contextlib
import
contextmanager
from
typing
import
*
import
math
from
..modules
import
sparse
as
sp
from
..utils.elastic_utils
import
ElasticModuleMixin
class
SparseTransformerElasticMixin
(
ElasticModuleMixin
):
def
_get_input_size
(
self
,
x
:
sp
.
SparseTensor
,
*
args
,
**
kwargs
):
return
x
.
feats
.
shape
[
0
]
@
contextmanager
def
with_mem_ratio
(
self
,
mem_ratio
=
1.0
):
if
mem_ratio
==
1.0
:
yield
1.0
return
num_blocks
=
len
(
self
.
blocks
)
num_checkpoint_blocks
=
min
(
math
.
ceil
((
1
-
mem_ratio
)
*
num_blocks
)
+
1
,
num_blocks
)
exact_mem_ratio
=
1
-
(
num_checkpoint_blocks
-
1
)
/
num_blocks
for
i
in
range
(
num_blocks
):
self
.
blocks
[
i
].
use_checkpoint
=
i
<
num_checkpoint_blocks
yield
exact_mem_ratio
for
i
in
range
(
num_blocks
):
self
.
blocks
[
i
].
use_checkpoint
=
False
TRELLIS.2_DCU/trellis2/models/sparse_structure_flow.py
0 → 100644
View file @
f05e915f
from
typing
import
*
from
functools
import
partial
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
..modules.utils
import
convert_module_to
,
manual_cast
,
str_to_dtype
from
..modules.transformer
import
AbsolutePositionEmbedder
,
ModulatedTransformerCrossBlock
from
..modules.attention
import
RotaryPositionEmbedder
class
TimestepEmbedder
(
nn
.
Module
):
"""
Embeds scalar timesteps into vector representations.
"""
def
__init__
(
self
,
hidden_size
,
frequency_embedding_size
=
256
):
super
().
__init__
()
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
frequency_embedding_size
,
hidden_size
,
bias
=
True
),
nn
.
SiLU
(),
nn
.
Linear
(
hidden_size
,
hidden_size
,
bias
=
True
),
)
self
.
frequency_embedding_size
=
frequency_embedding_size
@
staticmethod
def
timestep_embedding
(
t
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
Args:
t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim: the dimension of the output.
max_period: controls the minimum frequency of the embeddings.
Returns:
an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
np
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
t
.
device
)
args
=
t
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
def
forward
(
self
,
t
):
t_freq
=
self
.
timestep_embedding
(
t
,
self
.
frequency_embedding_size
)
t_emb
=
self
.
mlp
(
t_freq
)
return
t_emb
class
SparseStructureFlowModel
(
nn
.
Module
):
def
__init__
(
self
,
resolution
:
int
,
in_channels
:
int
,
model_channels
:
int
,
cond_channels
:
int
,
out_channels
:
int
,
num_blocks
:
int
,
num_heads
:
Optional
[
int
]
=
None
,
num_head_channels
:
Optional
[
int
]
=
64
,
mlp_ratio
:
float
=
4
,
pe_mode
:
Literal
[
"ape"
,
"rope"
]
=
"ape"
,
rope_freq
:
Tuple
[
float
,
float
]
=
(
1.0
,
10000.0
),
dtype
:
str
=
'float32'
,
use_checkpoint
:
bool
=
False
,
share_mod
:
bool
=
False
,
initialization
:
str
=
'vanilla'
,
qk_rms_norm
:
bool
=
False
,
qk_rms_norm_cross
:
bool
=
False
,
**
kwargs
):
super
().
__init__
()
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
cond_channels
=
cond_channels
self
.
out_channels
=
out_channels
self
.
num_blocks
=
num_blocks
self
.
num_heads
=
num_heads
or
model_channels
//
num_head_channels
self
.
mlp_ratio
=
mlp_ratio
self
.
pe_mode
=
pe_mode
self
.
use_checkpoint
=
use_checkpoint
self
.
share_mod
=
share_mod
self
.
initialization
=
initialization
self
.
qk_rms_norm
=
qk_rms_norm
self
.
qk_rms_norm_cross
=
qk_rms_norm_cross
self
.
dtype
=
str_to_dtype
(
dtype
)
self
.
t_embedder
=
TimestepEmbedder
(
model_channels
)
if
share_mod
:
self
.
adaLN_modulation
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
model_channels
,
6
*
model_channels
,
bias
=
True
)
)
if
pe_mode
==
"ape"
:
pos_embedder
=
AbsolutePositionEmbedder
(
model_channels
,
3
)
coords
=
torch
.
meshgrid
(
*
[
torch
.
arange
(
res
,
device
=
self
.
device
)
for
res
in
[
resolution
]
*
3
],
indexing
=
'ij'
)
coords
=
torch
.
stack
(
coords
,
dim
=-
1
).
reshape
(
-
1
,
3
)
pos_emb
=
pos_embedder
(
coords
)
self
.
register_buffer
(
"pos_emb"
,
pos_emb
)
elif
pe_mode
==
"rope"
:
pos_embedder
=
RotaryPositionEmbedder
(
self
.
model_channels
//
self
.
num_heads
,
3
)
coords
=
torch
.
meshgrid
(
*
[
torch
.
arange
(
res
,
device
=
self
.
device
)
for
res
in
[
resolution
]
*
3
],
indexing
=
'ij'
)
coords
=
torch
.
stack
(
coords
,
dim
=-
1
).
reshape
(
-
1
,
3
)
rope_phases
=
pos_embedder
(
coords
)
self
.
register_buffer
(
"rope_phases"
,
rope_phases
)
if
pe_mode
!=
"rope"
:
self
.
rope_phases
=
None
self
.
input_layer
=
nn
.
Linear
(
in_channels
,
model_channels
)
self
.
blocks
=
nn
.
ModuleList
([
ModulatedTransformerCrossBlock
(
model_channels
,
cond_channels
,
num_heads
=
self
.
num_heads
,
mlp_ratio
=
self
.
mlp_ratio
,
attn_mode
=
'full'
,
use_checkpoint
=
self
.
use_checkpoint
,
use_rope
=
(
pe_mode
==
"rope"
),
rope_freq
=
rope_freq
,
share_mod
=
share_mod
,
qk_rms_norm
=
self
.
qk_rms_norm
,
qk_rms_norm_cross
=
self
.
qk_rms_norm_cross
,
)
for
_
in
range
(
num_blocks
)
])
self
.
out_layer
=
nn
.
Linear
(
model_channels
,
out_channels
)
self
.
initialize_weights
()
self
.
convert_to
(
self
.
dtype
)
@
property
def
device
(
self
)
->
torch
.
device
:
"""
Return the device of the model.
"""
return
next
(
self
.
parameters
()).
device
def
convert_to
(
self
,
dtype
:
torch
.
dtype
)
->
None
:
"""
Convert the torso of the model to the specified dtype.
"""
self
.
dtype
=
dtype
self
.
blocks
.
apply
(
partial
(
convert_module_to
,
dtype
=
dtype
))
def
initialize_weights
(
self
)
->
None
:
if
self
.
initialization
==
'vanilla'
:
# Initialize transformer layers:
def
_basic_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
xavier_uniform_
(
module
.
weight
)
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
self
.
apply
(
_basic_init
)
# Initialize timestep embedding MLP:
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
0
].
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
2
].
weight
,
std
=
0.02
)
# Zero-out adaLN modulation layers in DiT blocks:
if
self
.
share_mod
:
nn
.
init
.
constant_
(
self
.
adaLN_modulation
[
-
1
].
weight
,
0
)
nn
.
init
.
constant_
(
self
.
adaLN_modulation
[
-
1
].
bias
,
0
)
else
:
for
block
in
self
.
blocks
:
nn
.
init
.
constant_
(
block
.
adaLN_modulation
[
-
1
].
weight
,
0
)
nn
.
init
.
constant_
(
block
.
adaLN_modulation
[
-
1
].
bias
,
0
)
# Zero-out output layers:
nn
.
init
.
constant_
(
self
.
out_layer
.
weight
,
0
)
nn
.
init
.
constant_
(
self
.
out_layer
.
bias
,
0
)
elif
self
.
initialization
==
'scaled'
:
# Initialize transformer layers:
def
_basic_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
normal_
(
module
.
weight
,
std
=
np
.
sqrt
(
2.0
/
(
5.0
*
self
.
model_channels
)))
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
self
.
apply
(
_basic_init
)
# Scaled init for to_out and ffn2
def
_scaled_init
(
module
):
if
isinstance
(
module
,
nn
.
Linear
):
torch
.
nn
.
init
.
normal_
(
module
.
weight
,
std
=
1.0
/
np
.
sqrt
(
5
*
self
.
num_blocks
*
self
.
model_channels
))
if
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0
)
for
block
in
self
.
blocks
:
block
.
self_attn
.
to_out
.
apply
(
_scaled_init
)
block
.
cross_attn
.
to_out
.
apply
(
_scaled_init
)
block
.
mlp
.
mlp
[
2
].
apply
(
_scaled_init
)
# Initialize input layer to make the initial representation have variance 1
nn
.
init
.
normal_
(
self
.
input_layer
.
weight
,
std
=
1.0
/
np
.
sqrt
(
self
.
in_channels
))
nn
.
init
.
zeros_
(
self
.
input_layer
.
bias
)
# Initialize timestep embedding MLP:
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
0
].
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
t_embedder
.
mlp
[
2
].
weight
,
std
=
0.02
)
# Zero-out adaLN modulation layers in DiT blocks:
if
self
.
share_mod
:
nn
.
init
.
constant_
(
self
.
adaLN_modulation
[
-
1
].
weight
,
0
)
nn
.
init
.
constant_
(
self
.
adaLN_modulation
[
-
1
].
bias
,
0
)
else
:
for
block
in
self
.
blocks
:
nn
.
init
.
constant_
(
block
.
adaLN_modulation
[
-
1
].
weight
,
0
)
nn
.
init
.
constant_
(
block
.
adaLN_modulation
[
-
1
].
bias
,
0
)
# Zero-out output layers:
nn
.
init
.
constant_
(
self
.
out_layer
.
weight
,
0
)
nn
.
init
.
constant_
(
self
.
out_layer
.
bias
,
0
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
cond
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
[
*
x
.
shape
]
==
[
x
.
shape
[
0
],
self
.
in_channels
,
*
[
self
.
resolution
]
*
3
],
\
f
"Input shape mismatch, got
{
x
.
shape
}
, expected
{
[
x
.
shape
[
0
],
self
.
in_channels
,
*
[
self
.
resolution
]
*
3
]
}
"
h
=
x
.
view
(
*
x
.
shape
[:
2
],
-
1
).
permute
(
0
,
2
,
1
).
contiguous
()
h
=
self
.
input_layer
(
h
)
if
self
.
pe_mode
==
"ape"
:
h
=
h
+
self
.
pos_emb
[
None
]
t_emb
=
self
.
t_embedder
(
t
)
if
self
.
share_mod
:
t_emb
=
self
.
adaLN_modulation
(
t_emb
)
t_emb
=
manual_cast
(
t_emb
,
self
.
dtype
)
h
=
manual_cast
(
h
,
self
.
dtype
)
cond
=
manual_cast
(
cond
,
self
.
dtype
)
for
block
in
self
.
blocks
:
h
=
block
(
h
,
t_emb
,
cond
,
self
.
rope_phases
)
h
=
manual_cast
(
h
,
x
.
dtype
)
h
=
F
.
layer_norm
(
h
,
h
.
shape
[
-
1
:])
h
=
self
.
out_layer
(
h
)
h
=
h
.
permute
(
0
,
2
,
1
).
view
(
h
.
shape
[
0
],
h
.
shape
[
2
],
*
[
self
.
resolution
]
*
3
).
contiguous
()
return
h
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment