Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3167 additions
and
0 deletions
+3167
-0
TRELLIS.2_DCU/trellis2/representations/__pycache__/__init__.cpython-310.pyc
...lis2/representations/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/representations/mesh/__init__.py
TRELLIS.2_DCU/trellis2/representations/mesh/__init__.py
+1
-0
TRELLIS.2_DCU/trellis2/representations/mesh/__pycache__/__init__.cpython-310.pyc
...representations/mesh/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/representations/mesh/__pycache__/base.cpython-310.pyc
...is2/representations/mesh/__pycache__/base.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/representations/mesh/base.py
TRELLIS.2_DCU/trellis2/representations/mesh/base.py
+311
-0
TRELLIS.2_DCU/trellis2/representations/voxel/__init__.py
TRELLIS.2_DCU/trellis2/representations/voxel/__init__.py
+1
-0
TRELLIS.2_DCU/trellis2/representations/voxel/__pycache__/__init__.cpython-310.pyc
...epresentations/voxel/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/representations/voxel/__pycache__/voxel_model.cpython-310.pyc
...esentations/voxel/__pycache__/voxel_model.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/representations/voxel/voxel_model.py
TRELLIS.2_DCU/trellis2/representations/voxel/voxel_model.py
+54
-0
TRELLIS.2_DCU/trellis2/trainers/__init__.py
TRELLIS.2_DCU/trellis2/trainers/__init__.py
+68
-0
TRELLIS.2_DCU/trellis2/trainers/basic.py
TRELLIS.2_DCU/trellis2/trainers/basic.py
+910
-0
TRELLIS.2_DCU/trellis2/trainers/flow_matching/flow_matching.py
...IS.2_DCU/trellis2/trainers/flow_matching/flow_matching.py
+353
-0
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py
...trainers/flow_matching/mixins/classifier_free_guidance.py
+59
-0
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/image_conditioned.py
...ellis2/trainers/flow_matching/mixins/image_conditioned.py
+249
-0
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/text_conditioned.py
...rellis2/trainers/flow_matching/mixins/text_conditioned.py
+68
-0
TRELLIS.2_DCU/trellis2/trainers/flow_matching/sparse_flow_matching.py
...U/trellis2/trainers/flow_matching/sparse_flow_matching.py
+325
-0
TRELLIS.2_DCU/trellis2/trainers/utils.py
TRELLIS.2_DCU/trellis2/trainers/utils.py
+91
-0
TRELLIS.2_DCU/trellis2/trainers/vae/pbr_vae.py
TRELLIS.2_DCU/trellis2/trainers/vae/pbr_vae.py
+281
-0
TRELLIS.2_DCU/trellis2/trainers/vae/shape_vae.py
TRELLIS.2_DCU/trellis2/trainers/vae/shape_vae.py
+266
-0
TRELLIS.2_DCU/trellis2/trainers/vae/sparse_structure_vae.py
TRELLIS.2_DCU/trellis2/trainers/vae/sparse_structure_vae.py
+130
-0
No files found.
TRELLIS.2_DCU/trellis2/representations/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/representations/mesh/__init__.py
0 → 100644
View file @
f05e915f
from
.base
import
Mesh
,
MeshWithVoxel
,
MeshWithPbrMaterial
,
TextureFilterMode
,
TextureWrapMode
,
AlphaMode
,
PbrMaterial
,
Texture
TRELLIS.2_DCU/trellis2/representations/mesh/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/representations/mesh/__pycache__/base.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/representations/mesh/base.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
from
..voxel
import
Voxel
import
cumesh
from
flex_gemm.ops.grid_sample
import
grid_sample_3d
from
...utils.pipeline_logger
import
get_logger
,
log_mesh
,
elapsed
class
Mesh
:
def
__init__
(
self
,
vertices
,
faces
,
vertex_attrs
=
None
):
self
.
vertices
=
vertices
.
float
()
self
.
faces
=
faces
.
int
()
self
.
vertex_attrs
=
vertex_attrs
@
property
def
device
(
self
):
return
self
.
vertices
.
device
def
to
(
self
,
device
,
non_blocking
=
False
):
return
Mesh
(
self
.
vertices
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
faces
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
vertex_attrs
.
to
(
device
,
non_blocking
=
non_blocking
)
if
self
.
vertex_attrs
is
not
None
else
None
,
)
def
cuda
(
self
,
non_blocking
=
False
):
return
self
.
to
(
'cuda'
,
non_blocking
=
non_blocking
)
def
cpu
(
self
):
return
self
.
to
(
'cpu'
)
def
fill_holes
(
self
,
max_hole_perimeter
=
3e-2
):
import
os
,
numpy
as
np
L
=
get_logger
()
log_mesh
(
self
.
vertices
,
self
.
faces
,
"fill_holes:before"
)
vertices
=
self
.
vertices
.
cuda
()
faces
=
self
.
faces
.
cuda
()
# ------------------------------------------------------------------ #
# Debug helpers: per-step .obj dump + stats print
# ------------------------------------------------------------------ #
_dbg_dir
=
os
.
environ
.
get
(
"CUMESH_DEBUG_DIR"
,
"cumesh_debug"
)
_dbg_step
=
[
0
]
def
_snap
(
label
,
v_tensor
,
f_tensor
):
return
"""Dump vertex/face data to an OBJ and print min/max/nan stats."""
v
=
v_tensor
.
detach
().
cpu
().
float
().
numpy
()
# [N, 3]
f
=
f_tensor
.
detach
().
cpu
().
int
().
numpy
()
# [M, 3]
step
=
_dbg_step
[
0
]
_dbg_step
[
0
]
+=
1
vmin
=
v
.
min
(
axis
=
0
)
if
len
(
v
)
else
[
float
(
'nan'
)]
*
3
vmax
=
v
.
max
(
axis
=
0
)
if
len
(
v
)
else
[
float
(
'nan'
)]
*
3
all_zero_v
=
bool
((
v
==
0
).
all
())
if
len
(
v
)
else
True
all_zero_f
=
bool
((
f
==
0
).
all
())
if
len
(
f
)
else
True
nan_v
=
bool
(
np
.
isnan
(
v
).
any
())
print
(
f
"[CUMESH_DBG] step=
{
step
:
02
d
}
{
label
}
"
)
print
(
f
" verts :
{
v
.
shape
[
0
]
}
min=
{
vmin
}
max=
{
vmax
}
all_zero=
{
all_zero_v
}
nan=
{
nan_v
}
"
)
print
(
f
" faces :
{
f
.
shape
[
0
]
}
all_zero=
{
all_zero_f
}
"
)
os
.
makedirs
(
_dbg_dir
,
exist_ok
=
True
)
obj_path
=
os
.
path
.
join
(
_dbg_dir
,
f
"step
{
step
:
02
d
}
_
{
label
.
replace
(
':'
,
'_'
).
replace
(
'/'
,
'_'
)
}
.obj"
)
with
open
(
obj_path
,
"w"
)
as
fp
:
fp
.
write
(
f
"# step=
{
step
}
{
label
}
\n
"
)
fp
.
write
(
f
"#
{
v
.
shape
[
0
]
}
vertices,
{
f
.
shape
[
0
]
}
faces
\n\n
"
)
for
row
in
v
:
fp
.
write
(
f
"v
{
row
[
0
]:.
6
f
}
{
row
[
1
]:.
6
f
}
{
row
[
2
]:.
6
f
}
\n
"
)
fp
.
write
(
"
\n
"
)
for
row
in
f
:
fp
.
write
(
f
"f
{
row
[
0
]
+
1
}
{
row
[
1
]
+
1
}
{
row
[
2
]
+
1
}
\n
"
)
print
(
f
" ->
{
obj_path
}
"
)
def
_snap_mesh
(
label
):
return
"""Read current CuMesh state and dump it."""
v
,
f
=
mesh
.
read
()
_snap
(
label
,
v
,
f
)
# ------------------------------------------------------------------ #
mesh
=
cumesh
.
CuMesh
()
mesh
.
init
(
vertices
,
faces
)
_snap
(
"00_after_init"
,
vertices
,
faces
)
mesh
.
get_edges
()
_snap_mesh
(
"01_after_get_edges"
)
mesh
.
get_boundary_info
()
L
.
info
(
f
"
{
elapsed
()
}
fill_holes: num_boundaries=
{
mesh
.
num_boundaries
}
"
)
_snap_mesh
(
"02_after_get_boundary_info"
)
if
mesh
.
num_boundaries
==
0
:
L
.
info
(
f
"
{
elapsed
()
}
fill_holes: no boundaries, skipping"
)
return
mesh
.
get_vertex_edge_adjacency
()
_snap_mesh
(
"03_after_get_vertex_edge_adjacency"
)
mesh
.
get_vertex_boundary_adjacency
()
_snap_mesh
(
"04_after_get_vertex_boundary_adjacency"
)
mesh
.
get_manifold_boundary_adjacency
()
_snap_mesh
(
"05_after_get_manifold_boundary_adjacency"
)
mesh
.
read_manifold_boundary_adjacency
()
_snap_mesh
(
"06_after_read_manifold_boundary_adjacency"
)
mesh
.
get_boundary_connected_components
()
_snap_mesh
(
"07_after_get_boundary_connected_components"
)
mesh
.
get_boundary_loops
()
L
.
info
(
f
"
{
elapsed
()
}
fill_holes: num_boundary_loops=
{
mesh
.
num_boundary_loops
}
"
)
_snap_mesh
(
"08_after_get_boundary_loops"
)
if
mesh
.
num_boundary_loops
==
0
:
return
mesh
.
fill_holes
(
max_hole_perimeter
=
max_hole_perimeter
)
_snap_mesh
(
"09_after_fill_holes"
)
new_vertices
,
new_faces
=
mesh
.
read
()
_snap
(
"10_final_read"
,
new_vertices
,
new_faces
)
log_mesh
(
new_vertices
,
new_faces
,
"fill_holes:after"
)
self
.
vertices
=
new_vertices
.
to
(
self
.
device
)
self
.
faces
=
new_faces
.
to
(
self
.
device
)
def
remove_faces
(
self
,
face_mask
:
torch
.
Tensor
):
vertices
=
self
.
vertices
.
cuda
()
faces
=
self
.
faces
.
cuda
()
mesh
=
cumesh
.
CuMesh
()
mesh
.
init
(
vertices
,
faces
)
mesh
.
remove_faces
(
face_mask
)
new_vertices
,
new_faces
=
mesh
.
read
()
self
.
vertices
=
new_vertices
.
to
(
self
.
device
)
self
.
faces
=
new_faces
.
to
(
self
.
device
)
def
simplify
(
self
,
target
=
1000000
,
verbose
:
bool
=
False
,
options
:
dict
=
{}):
L
=
get_logger
()
log_mesh
(
self
.
vertices
,
self
.
faces
,
f
"simplify:before(target=
{
target
}
)"
)
vertices
=
self
.
vertices
.
cuda
()
faces
=
self
.
faces
.
cuda
()
mesh
=
cumesh
.
CuMesh
()
mesh
.
init
(
vertices
,
faces
)
mesh
.
simplify
(
target
,
verbose
=
verbose
,
options
=
options
)
new_vertices
,
new_faces
=
mesh
.
read
()
log_mesh
(
new_vertices
,
new_faces
,
"simplify:after"
)
self
.
vertices
=
new_vertices
.
to
(
self
.
device
)
self
.
faces
=
new_faces
.
to
(
self
.
device
)
class
TextureFilterMode
:
CLOSEST
=
0
LINEAR
=
1
class
TextureWrapMode
:
CLAMP_TO_EDGE
=
0
REPEAT
=
1
MIRRORED_REPEAT
=
2
class
AlphaMode
:
OPAQUE
=
0
MASK
=
1
BLEND
=
2
class
Texture
:
def
__init__
(
self
,
image
:
torch
.
Tensor
,
filter_mode
:
TextureFilterMode
=
TextureFilterMode
.
LINEAR
,
wrap_mode
:
TextureWrapMode
=
TextureWrapMode
.
REPEAT
):
self
.
image
=
image
self
.
filter_mode
=
filter_mode
self
.
wrap_mode
=
wrap_mode
def
to
(
self
,
device
,
non_blocking
=
False
):
return
Texture
(
self
.
image
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
filter_mode
,
self
.
wrap_mode
,
)
class
PbrMaterial
:
def
__init__
(
self
,
base_color_texture
:
Optional
[
Texture
]
=
None
,
base_color_factor
:
Union
[
torch
.
Tensor
,
List
[
float
]]
=
[
1.0
,
1.0
,
1.0
],
metallic_texture
:
Optional
[
Texture
]
=
None
,
metallic_factor
:
float
=
1.0
,
roughness_texture
:
Optional
[
Texture
]
=
None
,
roughness_factor
:
float
=
1.0
,
alpha_texture
:
Optional
[
Texture
]
=
None
,
alpha_factor
:
float
=
1.0
,
alpha_mode
:
AlphaMode
=
AlphaMode
.
OPAQUE
,
alpha_cutoff
:
float
=
0.5
,
):
self
.
base_color_texture
=
base_color_texture
self
.
base_color_factor
=
torch
.
tensor
(
base_color_factor
,
dtype
=
torch
.
float32
)[:
3
]
self
.
metallic_texture
=
metallic_texture
self
.
metallic_factor
=
metallic_factor
self
.
roughness_texture
=
roughness_texture
self
.
roughness_factor
=
roughness_factor
self
.
alpha_texture
=
alpha_texture
self
.
alpha_factor
=
alpha_factor
self
.
alpha_mode
=
alpha_mode
self
.
alpha_cutoff
=
alpha_cutoff
def
to
(
self
,
device
,
non_blocking
=
False
):
return
PbrMaterial
(
base_color_texture
=
self
.
base_color_texture
.
to
(
device
,
non_blocking
=
non_blocking
)
if
self
.
base_color_texture
is
not
None
else
None
,
base_color_factor
=
self
.
base_color_factor
.
to
(
device
,
non_blocking
=
non_blocking
),
metallic_texture
=
self
.
metallic_texture
.
to
(
device
,
non_blocking
=
non_blocking
)
if
self
.
metallic_texture
is
not
None
else
None
,
metallic_factor
=
self
.
metallic_factor
,
roughness_texture
=
self
.
roughness_texture
.
to
(
device
,
non_blocking
=
non_blocking
)
if
self
.
roughness_texture
is
not
None
else
None
,
roughness_factor
=
self
.
roughness_factor
,
alpha_texture
=
self
.
alpha_texture
.
to
(
device
,
non_blocking
=
non_blocking
)
if
self
.
alpha_texture
is
not
None
else
None
,
alpha_factor
=
self
.
alpha_factor
,
alpha_mode
=
self
.
alpha_mode
,
alpha_cutoff
=
self
.
alpha_cutoff
,
)
class
MeshWithPbrMaterial
(
Mesh
):
def
__init__
(
self
,
vertices
,
faces
,
material_ids
,
uv_coords
,
materials
:
List
[
PbrMaterial
],
):
self
.
vertices
=
vertices
.
float
()
self
.
faces
=
faces
.
int
()
self
.
material_ids
=
material_ids
# [M]
self
.
uv_coords
=
uv_coords
# [M, 3, 2]
self
.
materials
=
materials
self
.
layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
def
to
(
self
,
device
,
non_blocking
=
False
):
return
MeshWithPbrMaterial
(
self
.
vertices
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
faces
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
material_ids
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
uv_coords
.
to
(
device
,
non_blocking
=
non_blocking
),
[
material
.
to
(
device
,
non_blocking
=
non_blocking
)
for
material
in
self
.
materials
],
)
class
MeshWithVoxel
(
Mesh
,
Voxel
):
def
__init__
(
self
,
vertices
:
torch
.
Tensor
,
faces
:
torch
.
Tensor
,
origin
:
list
,
voxel_size
:
float
,
coords
:
torch
.
Tensor
,
attrs
:
torch
.
Tensor
,
voxel_shape
:
torch
.
Size
,
layout
:
Dict
=
{},
):
self
.
vertices
=
vertices
.
float
()
self
.
faces
=
faces
.
int
()
self
.
origin
=
torch
.
tensor
(
origin
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
self
.
voxel_size
=
voxel_size
self
.
coords
=
coords
self
.
attrs
=
attrs
self
.
voxel_shape
=
voxel_shape
self
.
layout
=
layout
def
to
(
self
,
device
,
non_blocking
=
False
):
return
MeshWithVoxel
(
self
.
vertices
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
faces
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
origin
.
tolist
(),
self
.
voxel_size
,
self
.
coords
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
attrs
.
to
(
device
,
non_blocking
=
non_blocking
),
self
.
voxel_shape
,
self
.
layout
,
)
def
query_attrs
(
self
,
xyz
):
grid
=
((
xyz
-
self
.
origin
)
/
self
.
voxel_size
).
reshape
(
1
,
-
1
,
3
)
vertex_attrs
=
grid_sample_3d
(
self
.
attrs
,
torch
.
cat
([
torch
.
zeros_like
(
self
.
coords
[...,
:
1
]),
self
.
coords
],
dim
=-
1
),
self
.
voxel_shape
,
grid
,
mode
=
'trilinear'
)[
0
]
return
vertex_attrs
def
query_vertex_attrs
(
self
):
return
self
.
query_attrs
(
self
.
vertices
)
TRELLIS.2_DCU/trellis2/representations/voxel/__init__.py
0 → 100644
View file @
f05e915f
from
.voxel_model
import
Voxel
\ No newline at end of file
TRELLIS.2_DCU/trellis2/representations/voxel/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/representations/voxel/__pycache__/voxel_model.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/representations/voxel/voxel_model.py
0 → 100644
View file @
f05e915f
from
typing
import
Dict
import
torch
class
Voxel
:
def
__init__
(
self
,
origin
:
list
,
voxel_size
:
float
,
coords
:
torch
.
Tensor
=
None
,
attrs
:
torch
.
Tensor
=
None
,
layout
:
Dict
=
{},
device
:
torch
.
device
=
'cuda'
):
self
.
origin
=
torch
.
tensor
(
origin
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
voxel_size
=
voxel_size
self
.
coords
=
coords
self
.
attrs
=
attrs
self
.
layout
=
layout
self
.
device
=
device
@
property
def
position
(
self
):
return
(
self
.
coords
+
0.5
)
*
self
.
voxel_size
+
self
.
origin
[
None
,
:]
def
split_attrs
(
self
):
return
{
k
:
self
.
attrs
[:,
self
.
layout
[
k
]]
for
k
in
self
.
layout
}
def
save
(
self
,
path
):
# lazy import
if
'o_voxel'
not
in
globals
():
import
o_voxel
o_voxel
.
io
.
write
(
path
,
self
.
coords
,
self
.
split_attrs
(),
)
def
load
(
self
,
path
):
# lazy import
if
'o_voxel'
not
in
globals
():
import
o_voxel
coord
,
attrs
=
o_voxel
.
io
.
read
(
path
)
self
.
coords
=
coord
.
int
().
to
(
self
.
device
)
self
.
attrs
=
torch
.
cat
([
attrs
[
k
]
for
k
in
attrs
],
dim
=
1
).
to
(
self
.
device
)
# build layout
start
=
0
self
.
layout
=
{}
for
k
in
attrs
:
self
.
layout
[
k
]
=
slice
(
start
,
start
+
attrs
[
k
].
shape
[
1
])
start
+=
attrs
[
k
].
shape
[
1
]
TRELLIS.2_DCU/trellis2/trainers/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
'BasicTrainer'
:
'basic'
,
'SparseStructureVaeTrainer'
:
'vae.sparse_structure_vae'
,
'ShapeVaeTrainer'
:
'vae.shape_vae'
,
'PbrVaeTrainer'
:
'vae.pbr_vae'
,
'FlowMatchingTrainer'
:
'flow_matching.flow_matching'
,
'FlowMatchingCFGTrainer'
:
'flow_matching.flow_matching'
,
'TextConditionedFlowMatchingCFGTrainer'
:
'flow_matching.flow_matching'
,
'ImageConditionedFlowMatchingCFGTrainer'
:
'flow_matching.flow_matching'
,
'SparseFlowMatchingTrainer'
:
'flow_matching.sparse_flow_matching'
,
'SparseFlowMatchingCFGTrainer'
:
'flow_matching.sparse_flow_matching'
,
'TextConditionedSparseFlowMatchingCFGTrainer'
:
'flow_matching.sparse_flow_matching'
,
'ImageConditionedSparseFlowMatchingCFGTrainer'
:
'flow_matching.sparse_flow_matching'
,
'MultiImageConditionedSparseFlowMatchingCFGTrainer'
:
'flow_matching.sparse_flow_matching'
,
'DinoV2FeatureExtractor'
:
'flow_matching.mixins.image_conditioned'
,
'DinoV3FeatureExtractor'
:
'flow_matching.mixins.image_conditioned'
,
}
__submodules
=
[]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
# For Pylance
if
__name__
==
'__main__'
:
from
.basic
import
BasicTrainer
from
.vae.sparse_structure_vae
import
SparseStructureVaeTrainer
from
.vae.shape_vae
import
ShapeVaeTrainer
from
.vae.pbr_vae
import
PbrVaeTrainer
from
.flow_matching.flow_matching
import
(
FlowMatchingTrainer
,
FlowMatchingCFGTrainer
,
TextConditionedFlowMatchingCFGTrainer
,
ImageConditionedFlowMatchingCFGTrainer
,
)
from
.flow_matching.sparse_flow_matching
import
(
SparseFlowMatchingTrainer
,
SparseFlowMatchingCFGTrainer
,
TextConditionedSparseFlowMatchingCFGTrainer
,
ImageConditionedSparseFlowMatchingCFGTrainer
,
)
from
.flow_matching.mixins.image_conditioned
import
(
DinoV2FeatureExtractor
,
DinoV3FeatureExtractor
,
)
TRELLIS.2_DCU/trellis2/trainers/basic.py
0 → 100644
View file @
f05e915f
from
abc
import
abstractmethod
import
os
import
time
import
json
import
copy
import
threading
from
functools
import
partial
from
contextlib
import
nullcontext
import
torch
import
torch.distributed
as
dist
from
torch.utils.data
import
DataLoader
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
import
numpy
as
np
from
torchvision
import
utils
from
torch.utils.tensorboard
import
SummaryWriter
from
.utils
import
*
from
..utils.general_utils
import
*
from
..utils.data_utils
import
recursive_to_device
,
cycle
,
ResumableSampler
from
..utils.dist_utils
import
*
from
..utils
import
grad_clip_utils
,
elastic_utils
class
BasicTrainer
:
"""
Trainer for basic training loop.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
mix_precision_mode (str):
- None: No mixed precision.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
mix_precision_dtype (str): Mixed precision dtype.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
parallel_mode (str): Parallel mode. Options are 'ddp'.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
"""
def
__init__
(
self
,
models
,
dataset
,
*
,
output_dir
,
load_dir
,
step
,
max_steps
,
batch_size
=
None
,
batch_size_per_gpu
=
None
,
batch_split
=
None
,
optimizer
=
{},
lr_scheduler
=
None
,
elastic
=
None
,
grad_clip
=
None
,
ema_rate
=
0.9999
,
fp16_mode
=
None
,
mix_precision_mode
=
'inflat_all'
,
mix_precision_dtype
=
'float16'
,
fp16_scale_growth
=
1e-3
,
parallel_mode
=
'ddp'
,
finetune_ckpt
=
None
,
log_param_stats
=
False
,
prefetch_data
=
True
,
snapshot_batch_size
=
4
,
i_print
=
1000
,
i_log
=
500
,
i_sample
=
10000
,
i_save
=
10000
,
i_ddpcheck
=
10000
,
**
kwargs
):
assert
batch_size
is
not
None
or
batch_size_per_gpu
is
not
None
,
'Either batch_size or batch_size_per_gpu must be specified.'
self
.
models
=
models
self
.
dataset
=
dataset
self
.
batch_split
=
batch_split
if
batch_split
is
not
None
else
1
self
.
max_steps
=
max_steps
self
.
optimizer_config
=
optimizer
self
.
lr_scheduler_config
=
lr_scheduler
self
.
elastic_controller_config
=
elastic
self
.
grad_clip
=
grad_clip
self
.
ema_rate
=
[
ema_rate
]
if
isinstance
(
ema_rate
,
float
)
else
ema_rate
if
fp16_mode
is
not
None
:
mix_precision_dtype
=
'float16'
mix_precision_mode
=
fp16_mode
self
.
mix_precision_mode
=
mix_precision_mode
self
.
mix_precision_dtype
=
str_to_dtype
(
mix_precision_dtype
)
self
.
fp16_scale_growth
=
fp16_scale_growth
self
.
parallel_mode
=
parallel_mode
self
.
log_param_stats
=
log_param_stats
self
.
prefetch_data
=
prefetch_data
self
.
snapshot_batch_size
=
snapshot_batch_size
self
.
log
=
[]
if
self
.
prefetch_data
:
self
.
_data_prefetched
=
None
self
.
output_dir
=
output_dir
self
.
i_print
=
i_print
self
.
i_log
=
i_log
self
.
i_sample
=
i_sample
self
.
i_save
=
i_save
self
.
i_ddpcheck
=
i_ddpcheck
if
dist
.
is_initialized
():
# Multi-GPU params
self
.
world_size
=
dist
.
get_world_size
()
self
.
rank
=
dist
.
get_rank
()
self
.
local_rank
=
dist
.
get_rank
()
%
torch
.
cuda
.
device_count
()
self
.
is_master
=
self
.
rank
==
0
else
:
# Single-GPU params
self
.
world_size
=
1
self
.
rank
=
0
self
.
local_rank
=
0
self
.
is_master
=
True
self
.
batch_size
=
batch_size
if
batch_size_per_gpu
is
None
else
batch_size_per_gpu
*
self
.
world_size
self
.
batch_size_per_gpu
=
batch_size_per_gpu
if
batch_size_per_gpu
is
not
None
else
batch_size
//
self
.
world_size
assert
self
.
batch_size
%
self
.
world_size
==
0
,
'Batch size must be divisible by the number of GPUs.'
assert
self
.
batch_size_per_gpu
%
self
.
batch_split
==
0
,
'Batch size per GPU must be divisible by batch split.'
self
.
init_models_and_more
(
**
kwargs
)
self
.
prepare_dataloader
(
**
kwargs
)
# Load checkpoint
self
.
step
=
0
if
load_dir
is
not
None
and
step
is
not
None
:
self
.
load
(
load_dir
,
step
)
elif
finetune_ckpt
is
not
None
:
self
.
finetune_from
(
finetune_ckpt
)
if
self
.
is_master
:
os
.
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
'samples'
),
exist_ok
=
True
)
self
.
writer
=
SummaryWriter
(
os
.
path
.
join
(
self
.
output_dir
,
'tb_logs'
))
if
self
.
parallel_mode
==
'ddp'
and
self
.
world_size
>
1
:
self
.
check_ddp
()
if
self
.
is_master
:
print
(
'
\n\n
Trainer initialized.'
)
print
(
self
)
def
__str__
(
self
):
lines
=
[]
lines
.
append
(
self
.
__class__
.
__name__
)
lines
.
append
(
f
' - Models:'
)
for
name
,
model
in
self
.
models
.
items
():
lines
.
append
(
f
' -
{
name
}
:
{
model
.
__class__
.
__name__
}
'
)
lines
.
append
(
f
' - Dataset:
{
indent
(
str
(
self
.
dataset
),
2
)
}
'
)
lines
.
append
(
f
' - Dataloader:'
)
lines
.
append
(
f
' - Sampler:
{
self
.
dataloader
.
sampler
.
__class__
.
__name__
}
'
)
lines
.
append
(
f
' - Num workers:
{
self
.
dataloader
.
num_workers
}
'
)
lines
.
append
(
f
' - Number of steps:
{
self
.
max_steps
}
'
)
lines
.
append
(
f
' - Number of GPUs:
{
self
.
world_size
}
'
)
lines
.
append
(
f
' - Batch size:
{
self
.
batch_size
}
'
)
lines
.
append
(
f
' - Batch size per GPU:
{
self
.
batch_size_per_gpu
}
'
)
lines
.
append
(
f
' - Batch split:
{
self
.
batch_split
}
'
)
lines
.
append
(
f
' - Optimizer:
{
self
.
optimizer
.
__class__
.
__name__
}
'
)
lines
.
append
(
f
' - Learning rate:
{
self
.
optimizer
.
param_groups
[
0
][
"lr"
]
}
'
)
if
self
.
lr_scheduler_config
is
not
None
:
lines
.
append
(
f
' - LR scheduler:
{
self
.
lr_scheduler
.
__class__
.
__name__
}
'
)
if
self
.
elastic_controller_config
is
not
None
:
lines
.
append
(
f
' - Elastic memory:
{
indent
(
str
(
self
.
elastic_controller
),
2
)
}
'
)
if
self
.
grad_clip
is
not
None
:
lines
.
append
(
f
' - Gradient clip:
{
indent
(
str
(
self
.
grad_clip
),
2
)
}
'
)
lines
.
append
(
f
' - EMA rate:
{
self
.
ema_rate
}
'
)
lines
.
append
(
f
' - Mixed precision dtype:
{
self
.
mix_precision_dtype
}
'
)
lines
.
append
(
f
' - Mixed precision mode:
{
self
.
mix_precision_mode
}
'
)
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
lines
.
append
(
f
' - FP16 scale growth:
{
self
.
fp16_scale_growth
}
'
)
lines
.
append
(
f
' - Parallel mode:
{
self
.
parallel_mode
}
'
)
return
'
\n
'
.
join
(
lines
)
@
property
def
device
(
self
):
for
_
,
model
in
self
.
models
.
items
():
if
hasattr
(
model
,
'device'
):
return
model
.
device
return
next
(
list
(
self
.
models
.
values
())[
0
].
parameters
()).
device
def
init_models_and_more
(
self
,
**
kwargs
):
"""
Initialize models and more.
"""
if
self
.
world_size
>
1
:
# Prepare distributed data parallel
self
.
training_models
=
{
name
:
DDP
(
model
,
device_ids
=
[
self
.
local_rank
],
output_device
=
self
.
local_rank
,
bucket_cap_mb
=
128
,
find_unused_parameters
=
False
)
for
name
,
model
in
self
.
models
.
items
()
}
else
:
self
.
training_models
=
self
.
models
# Build master params
self
.
model_params
=
sum
(
[[
p
for
p
in
model
.
parameters
()
if
p
.
requires_grad
]
for
model
in
self
.
models
.
values
()]
,
[])
if
self
.
mix_precision_mode
==
'amp'
:
self
.
master_params
=
self
.
model_params
if
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
scaler
=
torch
.
GradScaler
()
elif
self
.
mix_precision_mode
==
'inflat_all'
:
self
.
master_params
=
make_master_params
(
self
.
model_params
)
if
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
log_scale
=
20.0
elif
self
.
mix_precision_mode
is
None
:
self
.
master_params
=
self
.
model_params
else
:
raise
NotImplementedError
(
f
'Mix precision mode
{
self
.
mix_precision_mode
}
is not implemented.'
)
# Build EMA params
if
self
.
is_master
:
self
.
ema_params
=
[
copy
.
deepcopy
(
self
.
master_params
)
for
_
in
self
.
ema_rate
]
# Initialize optimizer
if
hasattr
(
torch
.
optim
,
self
.
optimizer_config
[
'name'
]):
self
.
optimizer
=
getattr
(
torch
.
optim
,
self
.
optimizer_config
[
'name'
])(
self
.
master_params
,
**
self
.
optimizer_config
[
'args'
])
else
:
self
.
optimizer
=
globals
()[
self
.
optimizer_config
[
'name'
]](
self
.
master_params
,
**
self
.
optimizer_config
[
'args'
])
# Initalize learning rate scheduler
if
self
.
lr_scheduler_config
is
not
None
:
if
hasattr
(
torch
.
optim
.
lr_scheduler
,
self
.
lr_scheduler_config
[
'name'
]):
self
.
lr_scheduler
=
getattr
(
torch
.
optim
.
lr_scheduler
,
self
.
lr_scheduler_config
[
'name'
])(
self
.
optimizer
,
**
self
.
lr_scheduler_config
[
'args'
])
else
:
self
.
lr_scheduler
=
globals
()[
self
.
lr_scheduler_config
[
'name'
]](
self
.
optimizer
,
**
self
.
lr_scheduler_config
[
'args'
])
# Initialize elastic memory controller
if
self
.
elastic_controller_config
is
not
None
:
assert
any
([
isinstance
(
model
,
(
elastic_utils
.
ElasticModule
,
elastic_utils
.
ElasticModuleMixin
))
for
model
in
self
.
models
.
values
()]),
\
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
self
.
elastic_controller
=
getattr
(
elastic_utils
,
self
.
elastic_controller_config
[
'name'
])(
**
self
.
elastic_controller_config
[
'args'
])
for
model
in
self
.
models
.
values
():
if
isinstance
(
model
,
(
elastic_utils
.
ElasticModule
,
elastic_utils
.
ElasticModuleMixin
)):
model
.
register_memory_controller
(
self
.
elastic_controller
)
# Initialize gradient clipper
if
self
.
grad_clip
is
not
None
:
if
isinstance
(
self
.
grad_clip
,
(
float
,
int
)):
self
.
grad_clip
=
float
(
self
.
grad_clip
)
else
:
self
.
grad_clip
=
getattr
(
grad_clip_utils
,
self
.
grad_clip
[
'name'
])(
**
self
.
grad_clip
[
'args'
])
def
prepare_dataloader
(
self
,
**
kwargs
):
"""
Prepare dataloader.
"""
self
.
data_sampler
=
ResumableSampler
(
self
.
dataset
,
shuffle
=
True
,
)
self
.
dataloader
=
DataLoader
(
self
.
dataset
,
batch_size
=
self
.
batch_size_per_gpu
,
num_workers
=
int
(
np
.
ceil
(
os
.
cpu_count
()
/
torch
.
cuda
.
device_count
())),
pin_memory
=
True
,
drop_last
=
True
,
persistent_workers
=
True
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
sampler
=
self
.
data_sampler
,
)
self
.
data_iterator
=
cycle
(
self
.
dataloader
)
def
_master_params_to_state_dicts
(
self
,
master_params
):
"""
Convert master params to dict of state_dicts.
"""
if
self
.
mix_precision_mode
==
'inflat_all'
:
master_params
=
unflatten_master_params
(
self
.
model_params
,
master_params
)
state_dicts
=
{
name
:
model
.
state_dict
()
for
name
,
model
in
self
.
models
.
items
()}
master_params_names
=
sum
(
[[(
name
,
n
)
for
n
,
p
in
model
.
named_parameters
()
if
p
.
requires_grad
]
for
name
,
model
in
self
.
models
.
items
()]
,
[])
for
i
,
(
model_name
,
param_name
)
in
enumerate
(
master_params_names
):
state_dicts
[
model_name
][
param_name
]
=
master_params
[
i
]
return
state_dicts
def
_state_dicts_to_master_params
(
self
,
master_params
,
state_dicts
):
"""
Convert a state_dict to master params.
"""
master_params_names
=
sum
(
[[(
name
,
n
)
for
n
,
p
in
model
.
named_parameters
()
if
p
.
requires_grad
]
for
name
,
model
in
self
.
models
.
items
()]
,
[])
params
=
[
state_dicts
[
name
][
param_name
]
for
name
,
param_name
in
master_params_names
]
if
self
.
mix_precision_mode
==
'inflat_all'
:
model_params_to_master_params
(
params
,
master_params
)
else
:
for
i
,
param
in
enumerate
(
params
):
master_params
[
i
].
data
.
copy_
(
param
.
data
)
def
load
(
self
,
load_dir
,
step
=
0
):
"""
Load a checkpoint.
Should be called by all processes.
"""
if
self
.
is_master
:
print
(
f
'
\n
Loading checkpoint from step
{
step
}
...'
,
end
=
''
)
model_ckpts
=
{}
for
name
,
model
in
self
.
models
.
items
():
model_ckpt
=
torch
.
load
(
read_file_dist
(
os
.
path
.
join
(
load_dir
,
'ckpts'
,
f
'
{
name
}
_step
{
step
:
07
d
}
.pt'
)),
map_location
=
self
.
device
,
weights_only
=
True
)
model_ckpts
[
name
]
=
model_ckpt
model
.
load_state_dict
(
model_ckpt
)
self
.
_state_dicts_to_master_params
(
self
.
master_params
,
model_ckpts
)
del
model_ckpts
if
self
.
is_master
:
for
i
,
ema_rate
in
enumerate
(
self
.
ema_rate
):
ema_ckpts
=
{}
for
name
,
model
in
self
.
models
.
items
():
ema_ckpt
=
torch
.
load
(
os
.
path
.
join
(
load_dir
,
'ckpts'
,
f
'
{
name
}
_ema
{
ema_rate
}
_step
{
step
:
07
d
}
.pt'
),
map_location
=
self
.
device
,
weights_only
=
True
)
ema_ckpts
[
name
]
=
ema_ckpt
self
.
_state_dicts_to_master_params
(
self
.
ema_params
[
i
],
ema_ckpts
)
del
ema_ckpts
misc_ckpt
=
torch
.
load
(
read_file_dist
(
os
.
path
.
join
(
load_dir
,
'ckpts'
,
f
'misc_step
{
step
:
07
d
}
.pt'
)),
map_location
=
torch
.
device
(
'cpu'
),
weights_only
=
False
)
self
.
optimizer
.
load_state_dict
(
misc_ckpt
[
'optimizer'
])
self
.
step
=
misc_ckpt
[
'step'
]
self
.
data_sampler
.
load_state_dict
(
misc_ckpt
[
'data_sampler'
])
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
scaler
.
load_state_dict
(
misc_ckpt
[
'scaler'
])
elif
self
.
mix_precision_mode
==
'inflat_all'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
log_scale
=
misc_ckpt
[
'log_scale'
]
if
self
.
lr_scheduler_config
is
not
None
:
self
.
lr_scheduler
.
load_state_dict
(
misc_ckpt
[
'lr_scheduler'
])
if
self
.
elastic_controller_config
is
not
None
:
self
.
elastic_controller
.
load_state_dict
(
misc_ckpt
[
'elastic_controller'
])
if
self
.
grad_clip
is
not
None
and
not
isinstance
(
self
.
grad_clip
,
float
):
self
.
grad_clip
.
load_state_dict
(
misc_ckpt
[
'grad_clip'
])
del
misc_ckpt
if
self
.
world_size
>
1
:
dist
.
barrier
()
if
self
.
is_master
:
print
(
' Done.'
)
if
self
.
world_size
>
1
:
self
.
check_ddp
()
def
save
(
self
,
non_blocking
=
True
):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
assert
self
.
is_master
,
'save() should be called only by the rank 0 process.'
print
(
f
'
\n
Saving checkpoint at step
{
self
.
step
}
...'
,
end
=
''
)
model_ckpts
=
self
.
_master_params_to_state_dicts
(
self
.
master_params
)
for
name
,
model_ckpt
in
model_ckpts
.
items
():
model_ckpt
=
{
k
:
v
.
cpu
()
for
k
,
v
in
model_ckpt
.
items
()}
# Move to CPU for saving
if
non_blocking
:
threading
.
Thread
(
target
=
torch
.
save
,
args
=
(
model_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'
{
name
}
_step
{
self
.
step
:
07
d
}
.pt'
)),
).
start
()
else
:
torch
.
save
(
model_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'
{
name
}
_step
{
self
.
step
:
07
d
}
.pt'
))
for
i
,
ema_rate
in
enumerate
(
self
.
ema_rate
):
ema_ckpts
=
self
.
_master_params_to_state_dicts
(
self
.
ema_params
[
i
])
for
name
,
ema_ckpt
in
ema_ckpts
.
items
():
ema_ckpt
=
{
k
:
v
.
cpu
()
for
k
,
v
in
ema_ckpt
.
items
()}
# Move to CPU for saving
if
non_blocking
:
threading
.
Thread
(
target
=
torch
.
save
,
args
=
(
ema_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'
{
name
}
_ema
{
ema_rate
}
_step
{
self
.
step
:
07
d
}
.pt'
)),
).
start
()
else
:
torch
.
save
(
ema_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'
{
name
}
_ema
{
ema_rate
}
_step
{
self
.
step
:
07
d
}
.pt'
))
misc_ckpt
=
{
'optimizer'
:
self
.
optimizer
.
state_dict
(),
'step'
:
self
.
step
,
'data_sampler'
:
self
.
data_sampler
.
state_dict
(),
}
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
misc_ckpt
[
'scaler'
]
=
self
.
scaler
.
state_dict
()
elif
self
.
mix_precision_mode
==
'inflat_all'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
misc_ckpt
[
'log_scale'
]
=
self
.
log_scale
if
self
.
lr_scheduler_config
is
not
None
:
misc_ckpt
[
'lr_scheduler'
]
=
self
.
lr_scheduler
.
state_dict
()
if
self
.
elastic_controller_config
is
not
None
:
misc_ckpt
[
'elastic_controller'
]
=
self
.
elastic_controller
.
state_dict
()
if
self
.
grad_clip
is
not
None
and
not
isinstance
(
self
.
grad_clip
,
float
):
misc_ckpt
[
'grad_clip'
]
=
self
.
grad_clip
.
state_dict
()
if
non_blocking
:
threading
.
Thread
(
target
=
torch
.
save
,
args
=
(
misc_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'misc_step
{
self
.
step
:
07
d
}
.pt'
)),
).
start
()
else
:
torch
.
save
(
misc_ckpt
,
os
.
path
.
join
(
self
.
output_dir
,
'ckpts'
,
f
'misc_step
{
self
.
step
:
07
d
}
.pt'
))
print
(
' Done.'
)
def
finetune_from
(
self
,
finetune_ckpt
):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
if
self
.
is_master
:
print
(
'
\n
Finetuning from:'
)
for
name
,
path
in
finetune_ckpt
.
items
():
print
(
f
' -
{
name
}
:
{
path
}
'
)
model_ckpts
=
{}
for
name
,
model
in
self
.
models
.
items
():
model_state_dict
=
model
.
state_dict
()
if
name
in
finetune_ckpt
:
model_ckpt
=
torch
.
load
(
read_file_dist
(
finetune_ckpt
[
name
]),
map_location
=
self
.
device
,
weights_only
=
True
)
for
k
,
v
in
model_ckpt
.
items
():
if
k
not
in
model_state_dict
:
if
self
.
is_master
:
print
(
f
'Warning:
{
k
}
not found in model_state_dict, skipped.'
)
model_ckpt
[
k
]
=
None
elif
model_ckpt
[
k
].
shape
!=
model_state_dict
[
k
].
shape
:
if
self
.
is_master
:
print
(
f
'Warning:
{
k
}
shape mismatch,
{
model_ckpt
[
k
].
shape
}
vs
{
model_state_dict
[
k
].
shape
}
, skipped.'
)
model_ckpt
[
k
]
=
model_state_dict
[
k
]
model_ckpt
=
{
k
:
v
for
k
,
v
in
model_ckpt
.
items
()
if
v
is
not
None
}
model_ckpts
[
name
]
=
model_ckpt
model
.
load_state_dict
(
model_ckpt
)
else
:
if
self
.
is_master
:
print
(
f
'Warning:
{
name
}
not found in finetune_ckpt, skipped.'
)
model_ckpts
[
name
]
=
model_state_dict
self
.
_state_dicts_to_master_params
(
self
.
master_params
,
model_ckpts
)
if
self
.
is_master
:
for
i
,
ema_rate
in
enumerate
(
self
.
ema_rate
):
self
.
_state_dicts_to_master_params
(
self
.
ema_params
[
i
],
model_ckpts
)
del
model_ckpts
if
self
.
world_size
>
1
:
dist
.
barrier
()
if
self
.
is_master
:
print
(
'Done.'
)
if
self
.
world_size
>
1
:
self
.
check_ddp
()
@
abstractmethod
def
run_snapshot
(
self
,
num_samples
,
batch_size
=
4
,
verbose
=
False
,
**
kwargs
):
"""
Run a snapshot of the model.
"""
pass
@
torch
.
no_grad
()
def
visualize_sample
(
self
,
sample
):
"""
Convert a sample to an image.
"""
if
hasattr
(
self
.
dataset
,
'visualize_sample'
):
return
self
.
dataset
.
visualize_sample
(
sample
)
else
:
return
sample
@
torch
.
no_grad
()
def
snapshot_dataset
(
self
,
num_samples
=
100
,
batch_size
=
4
):
"""
Sample images from the dataset.
"""
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
self
.
dataset
,
batch_size
=
batch_size
,
num_workers
=
1
,
shuffle
=
True
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
save_cfg
=
{}
for
i
in
range
(
0
,
num_samples
,
batch_size
):
data
=
next
(
iter
(
dataloader
))
data
=
{
k
:
v
[:
min
(
num_samples
-
i
,
batch_size
)]
for
k
,
v
in
data
.
items
()}
data
=
recursive_to_device
(
data
,
self
.
device
)
vis
=
self
.
visualize_sample
(
data
)
if
isinstance
(
vis
,
dict
):
for
k
,
v
in
vis
.
items
():
if
f
'dataset_
{
k
}
'
not
in
save_cfg
:
save_cfg
[
f
'dataset_
{
k
}
'
]
=
[]
save_cfg
[
f
'dataset_
{
k
}
'
].
append
(
v
)
else
:
if
'dataset'
not
in
save_cfg
:
save_cfg
[
'dataset'
]
=
[]
save_cfg
[
'dataset'
].
append
(
vis
)
for
name
,
image
in
save_cfg
.
items
():
utils
.
save_image
(
torch
.
cat
(
image
,
dim
=
0
),
os
.
path
.
join
(
self
.
output_dir
,
'samples'
,
f
'
{
name
}
.jpg'
),
nrow
=
int
(
np
.
sqrt
(
num_samples
)),
normalize
=
True
,
value_range
=
self
.
dataset
.
value_range
,
)
@
torch
.
no_grad
()
def
snapshot
(
self
,
suffix
=
None
,
num_samples
=
64
,
batch_size
=
4
,
verbose
=
False
):
"""
Sample images from the model.
NOTE: This function should be called by all processes.
"""
if
self
.
is_master
:
print
(
f
'
\n
Sampling
{
num_samples
}
images...'
,
end
=
''
)
if
suffix
is
None
:
suffix
=
f
'step
{
self
.
step
:
07
d
}
'
# Assign tasks
num_samples_per_process
=
int
(
np
.
ceil
(
num_samples
/
self
.
world_size
))
amp_context
=
partial
(
torch
.
autocast
,
device_type
=
'cuda'
,
dtype
=
self
.
mix_precision_dtype
)
if
self
.
mix_precision_mode
==
'amp'
else
nullcontext
with
amp_context
():
samples
=
self
.
run_snapshot
(
num_samples_per_process
,
batch_size
=
batch_size
,
verbose
=
verbose
)
# Preprocess images
for
key
in
list
(
samples
.
keys
()):
if
samples
[
key
][
'type'
]
==
'sample'
:
vis
=
self
.
visualize_sample
(
samples
[
key
][
'value'
])
if
isinstance
(
vis
,
dict
):
for
k
,
v
in
vis
.
items
():
samples
[
f
'
{
key
}
_
{
k
}
'
]
=
{
'value'
:
v
,
'type'
:
'image'
}
del
samples
[
key
]
else
:
samples
[
key
]
=
{
'value'
:
vis
,
'type'
:
'image'
}
# Gather results
if
self
.
world_size
>
1
:
for
key
in
samples
.
keys
():
samples
[
key
][
'value'
]
=
samples
[
key
][
'value'
].
contiguous
()
if
self
.
is_master
:
all_images
=
[
torch
.
empty_like
(
samples
[
key
][
'value'
])
for
_
in
range
(
self
.
world_size
)]
else
:
all_images
=
[]
dist
.
gather
(
samples
[
key
][
'value'
],
all_images
,
dst
=
0
)
if
self
.
is_master
:
samples
[
key
][
'value'
]
=
torch
.
cat
(
all_images
,
dim
=
0
)[:
num_samples
]
# Save images
if
self
.
is_master
:
os
.
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
'samples'
,
suffix
),
exist_ok
=
True
)
for
key
in
samples
.
keys
():
if
samples
[
key
][
'type'
]
==
'image'
:
utils
.
save_image
(
samples
[
key
][
'value'
],
os
.
path
.
join
(
self
.
output_dir
,
'samples'
,
suffix
,
f
'
{
key
}
_
{
suffix
}
.jpg'
),
nrow
=
int
(
np
.
sqrt
(
num_samples
)),
normalize
=
True
,
value_range
=
self
.
dataset
.
value_range
,
)
elif
samples
[
key
][
'type'
]
==
'number'
:
min
=
samples
[
key
][
'value'
].
min
()
max
=
samples
[
key
][
'value'
].
max
()
images
=
(
samples
[
key
][
'value'
]
-
min
)
/
(
max
-
min
)
images
=
utils
.
make_grid
(
images
,
nrow
=
int
(
np
.
sqrt
(
num_samples
)),
normalize
=
False
,
)
save_image_with_notes
(
images
,
os
.
path
.
join
(
self
.
output_dir
,
'samples'
,
suffix
,
f
'
{
key
}
_
{
suffix
}
.jpg'
),
notes
=
f
'
{
key
}
min:
{
min
}
, max:
{
max
}
'
,
)
if
self
.
is_master
:
print
(
' Done.'
)
def
update_ema
(
self
):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
assert
self
.
is_master
,
'update_ema() should be called only by the rank 0 process.'
for
i
,
ema_rate
in
enumerate
(
self
.
ema_rate
):
for
master_param
,
ema_param
in
zip
(
self
.
master_params
,
self
.
ema_params
[
i
]):
ema_param
.
detach
().
mul_
(
ema_rate
).
add_
(
master_param
,
alpha
=
1.0
-
ema_rate
)
def
check_ddp
(
self
):
"""
Check if DDP is working properly.
Should be called by all process.
"""
if
self
.
is_master
:
print
(
'
\n
Performing DDP check...'
)
if
self
.
is_master
:
print
(
'Checking if parameters are consistent across processes...'
)
dist
.
barrier
()
try
:
for
p
in
self
.
master_params
:
# split to avoid OOM
for
i
in
range
(
0
,
p
.
numel
(),
10000000
):
sub_size
=
min
(
10000000
,
p
.
numel
()
-
i
)
sub_p
=
p
.
detach
().
view
(
-
1
)[
i
:
i
+
sub_size
]
# gather from all processes
sub_p_gather
=
[
torch
.
empty_like
(
sub_p
)
for
_
in
range
(
self
.
world_size
)]
dist
.
all_gather
(
sub_p_gather
,
sub_p
)
# check if equal
assert
all
([
torch
.
equal
(
sub_p
,
sub_p_gather
[
i
])
for
i
in
range
(
self
.
world_size
)]),
'parameters are not consistent across processes'
except
AssertionError
as
e
:
if
self
.
is_master
:
print
(
f
'
\n\033
[91mError:
{
e
}
\033
[0m'
)
print
(
'DDP check failed.'
)
raise
e
dist
.
barrier
()
if
self
.
is_master
:
print
(
'Done.'
)
@
abstractmethod
def
training_losses
(
**
mb_data
):
"""
Compute training losses.
"""
pass
def
load_data
(
self
):
"""
Load data.
"""
if
self
.
prefetch_data
:
if
self
.
_data_prefetched
is
None
:
self
.
_data_prefetched
=
recursive_to_device
(
next
(
self
.
data_iterator
),
self
.
device
,
non_blocking
=
True
)
data
=
self
.
_data_prefetched
self
.
_data_prefetched
=
recursive_to_device
(
next
(
self
.
data_iterator
),
self
.
device
,
non_blocking
=
True
)
else
:
data
=
recursive_to_device
(
next
(
self
.
data_iterator
),
self
.
device
,
non_blocking
=
True
)
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
if
isinstance
(
data
,
dict
):
if
self
.
batch_split
==
1
:
data_list
=
[
data
]
else
:
batch_size
=
list
(
data
.
values
())[
0
].
shape
[
0
]
data_list
=
[
{
k
:
v
[
i
*
batch_size
//
self
.
batch_split
:(
i
+
1
)
*
batch_size
//
self
.
batch_split
]
for
k
,
v
in
data
.
items
()}
for
i
in
range
(
self
.
batch_split
)
]
elif
isinstance
(
data
,
list
):
data_list
=
data
else
:
raise
ValueError
(
'Data must be a dict or a list of dicts.'
)
return
data_list
def
run_step
(
self
,
data_list
):
"""
Run a training step.
"""
step_log
=
{
'loss'
:
{},
'status'
:
{}}
amp_context
=
partial
(
torch
.
autocast
,
device_type
=
'cuda'
,
dtype
=
self
.
mix_precision_dtype
)
if
self
.
mix_precision_mode
==
'amp'
else
nullcontext
elastic_controller_context
=
self
.
elastic_controller
.
record
if
self
.
elastic_controller_config
is
not
None
else
nullcontext
# Train
losses
=
[]
statuses
=
[]
elastic_controller_logs
=
[]
zero_grad
(
self
.
model_params
)
for
i
,
mb_data
in
enumerate
(
data_list
):
## sync at the end of each batch split
sync_contexts
=
[
self
.
training_models
[
name
].
no_sync
for
name
in
self
.
training_models
]
if
i
!=
len
(
data_list
)
-
1
and
self
.
world_size
>
1
else
[
nullcontext
]
with
nested_contexts
(
*
sync_contexts
),
elastic_controller_context
():
with
amp_context
():
loss
,
status
=
self
.
training_losses
(
**
mb_data
)
l
=
loss
[
'loss'
]
/
len
(
data_list
)
## backward
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
scaler
.
scale
(
l
).
backward
()
elif
self
.
mix_precision_mode
==
'inflat_all'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
scaled_l
=
l
*
(
2
**
self
.
log_scale
)
scaled_l
.
backward
()
else
:
l
.
backward
()
## log
losses
.
append
(
dict_foreach
(
loss
,
lambda
x
:
x
.
item
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
))
statuses
.
append
(
dict_foreach
(
status
,
lambda
x
:
x
.
item
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
))
if
self
.
elastic_controller_config
is
not
None
:
elastic_controller_logs
.
append
(
self
.
elastic_controller
.
log
())
## gradient clip
if
self
.
grad_clip
is
not
None
:
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
scaler
.
unscale_
(
self
.
optimizer
)
elif
self
.
mix_precision_mode
==
'inflat_all'
:
model_grads_to_master_grads
(
self
.
model_params
,
self
.
master_params
)
if
self
.
mix_precision_dtype
==
torch
.
float16
:
self
.
master_params
[
0
].
grad
.
mul_
(
1.0
/
(
2
**
self
.
log_scale
))
if
isinstance
(
self
.
grad_clip
,
float
):
grad_norm
=
torch
.
nn
.
utils
.
clip_grad_norm_
(
self
.
master_params
,
self
.
grad_clip
)
else
:
grad_norm
=
self
.
grad_clip
(
self
.
master_params
)
if
torch
.
isfinite
(
grad_norm
):
statuses
[
-
1
][
'grad_norm'
]
=
grad_norm
.
item
()
## step
if
self
.
mix_precision_mode
==
'amp'
and
self
.
mix_precision_dtype
==
torch
.
float16
:
prev_scale
=
self
.
scaler
.
get_scale
()
self
.
scaler
.
step
(
self
.
optimizer
)
self
.
scaler
.
update
()
elif
self
.
mix_precision_mode
==
'inflat_all'
:
if
self
.
mix_precision_dtype
==
torch
.
float16
:
prev_scale
=
2
**
self
.
log_scale
if
not
any
(
not
p
.
grad
.
isfinite
().
all
()
for
p
in
self
.
model_params
):
if
self
.
grad_clip
is
None
:
model_grads_to_master_grads
(
self
.
model_params
,
self
.
master_params
)
self
.
master_params
[
0
].
grad
.
mul_
(
1.0
/
(
2
**
self
.
log_scale
))
self
.
optimizer
.
step
()
master_params_to_model_params
(
self
.
model_params
,
self
.
master_params
)
self
.
log_scale
+=
self
.
fp16_scale_growth
else
:
self
.
log_scale
-=
1
else
:
prev_scale
=
1.0
if
self
.
grad_clip
is
None
:
model_grads_to_master_grads
(
self
.
model_params
,
self
.
master_params
)
if
not
any
(
not
p
.
grad
.
isfinite
().
all
()
for
p
in
self
.
master_params
):
self
.
optimizer
.
step
()
master_params_to_model_params
(
self
.
model_params
,
self
.
master_params
)
else
:
print
(
'
\n\033
[93mWarning: NaN detected in gradients. Skipping update.
\033
[0m'
)
else
:
prev_scale
=
1.0
if
not
any
(
not
p
.
grad
.
isfinite
().
all
()
for
p
in
self
.
model_params
):
self
.
optimizer
.
step
()
else
:
print
(
'
\n\033
[93mWarning: NaN detected in gradients. Skipping update.
\033
[0m'
)
## adjust learning rate
if
self
.
lr_scheduler_config
is
not
None
:
statuses
[
-
1
][
'lr'
]
=
self
.
lr_scheduler
.
get_last_lr
()[
0
]
self
.
lr_scheduler
.
step
()
# Logs
step_log
[
'loss'
]
=
dict_reduce
(
losses
,
lambda
x
:
np
.
mean
(
x
))
step_log
[
'status'
]
=
dict_reduce
(
statuses
,
lambda
x
:
np
.
mean
(
x
),
special_func
=
{
'min'
:
lambda
x
:
np
.
min
(
x
),
'max'
:
lambda
x
:
np
.
max
(
x
)})
if
self
.
elastic_controller_config
is
not
None
:
step_log
[
'elastic'
]
=
dict_reduce
(
elastic_controller_logs
,
lambda
x
:
np
.
mean
(
x
))
if
self
.
grad_clip
is
not
None
:
step_log
[
'grad_clip'
]
=
self
.
grad_clip
if
isinstance
(
self
.
grad_clip
,
float
)
else
self
.
grad_clip
.
log
()
# Check grad and norm of each param
if
self
.
log_param_stats
:
param_norms
=
{}
param_grads
=
{}
for
model_name
,
model
in
self
.
models
.
items
():
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
param_norms
[
f
'
{
model_name
}
.
{
name
}
'
]
=
param
.
norm
().
item
()
if
param
.
grad
is
not
None
and
torch
.
isfinite
(
param
.
grad
).
all
():
param_grads
[
f
'
{
model_name
}
.
{
name
}
'
]
=
param
.
grad
.
norm
().
item
()
/
prev_scale
step_log
[
'param_norms'
]
=
param_norms
step_log
[
'param_grads'
]
=
param_grads
# Update exponential moving average
if
self
.
is_master
:
self
.
update_ema
()
return
step_log
def
save_logs
(
self
):
log_str
=
'
\n
'
.
join
([
f
'
{
step
}
:
{
json
.
dumps
(
dict_foreach
(
log
,
lambda
x
:
float
(
x
)))
}
'
for
step
,
log
in
self
.
log
])
with
open
(
os
.
path
.
join
(
self
.
output_dir
,
'log.txt'
),
'a'
)
as
log_file
:
log_file
.
write
(
log_str
+
'
\n
'
)
# show with mlflow
log_show
=
[
l
for
_
,
l
in
self
.
log
if
not
dict_any
(
l
,
lambda
x
:
np
.
isnan
(
x
))]
log_show
=
dict_reduce
(
log_show
,
lambda
x
:
np
.
mean
(
x
))
log_show
=
dict_flatten
(
log_show
,
sep
=
'/'
)
for
key
,
value
in
log_show
.
items
():
self
.
writer
.
add_scalar
(
key
,
value
,
self
.
step
)
self
.
log
=
[]
def
check_abort
(
self
):
"""
Check if training should be aborted due to certain conditions.
"""
# 1. If log_scale in inflat_all mode is less than 0
if
self
.
mix_precision_dtype
==
torch
.
float16
and
\
self
.
mix_precision_mode
==
'inflat_all'
and
\
self
.
log_scale
<
0
:
if
self
.
is_master
:
print
(
'
\n\n\033
[91m'
)
print
(
f
'ABORT: log_scale in inflat_all mode is less than 0 at step
{
self
.
step
}
.'
)
print
(
'This indicates that the model is diverging. You should look into the model and the data.'
)
print
(
'
\033
[0m'
)
self
.
save
(
non_blocking
=
False
)
self
.
save_logs
()
if
self
.
world_size
>
1
:
dist
.
barrier
()
raise
ValueError
(
'ABORT: log_scale in inflat_all mode is less than 0.'
)
def
run
(
self
):
"""
Run training.
"""
if
self
.
is_master
:
print
(
'
\n
Starting training...'
)
self
.
snapshot_dataset
(
batch_size
=
self
.
snapshot_batch_size
)
if
self
.
step
==
0
:
self
.
snapshot
(
suffix
=
'init'
,
batch_size
=
self
.
snapshot_batch_size
)
else
:
# resume
self
.
snapshot
(
suffix
=
f
'resume_step
{
self
.
step
:
07
d
}
'
,
batch_size
=
self
.
snapshot_batch_size
)
time_last_print
=
0.0
time_elapsed
=
0.0
while
self
.
step
<
self
.
max_steps
:
time_start
=
time
.
time
()
data_list
=
self
.
load_data
()
step_log
=
self
.
run_step
(
data_list
)
time_end
=
time
.
time
()
time_elapsed
+=
time_end
-
time_start
self
.
step
+=
1
# Print progress
if
self
.
is_master
and
self
.
step
%
self
.
i_print
==
0
:
speed
=
self
.
i_print
/
(
time_elapsed
-
time_last_print
)
*
3600
columns
=
[
f
'Step:
{
self
.
step
}
/
{
self
.
max_steps
}
(
{
self
.
step
/
self
.
max_steps
*
100
:.
2
f
}
%)'
,
f
'Elapsed:
{
time_elapsed
/
3600
:.
2
f
}
h'
,
f
'Speed:
{
speed
:.
2
f
}
steps/h'
,
f
'ETA:
{
(
self
.
max_steps
-
self
.
step
)
/
speed
:.
2
f
}
h'
,
]
print
(
' | '
.
join
([
c
.
ljust
(
25
)
for
c
in
columns
]),
flush
=
True
)
time_last_print
=
time_elapsed
# Check ddp
if
self
.
parallel_mode
==
'ddp'
and
self
.
world_size
>
1
and
self
.
i_ddpcheck
is
not
None
and
self
.
step
%
self
.
i_ddpcheck
==
0
:
self
.
check_ddp
()
# Sample images
if
self
.
step
%
self
.
i_sample
==
0
:
self
.
snapshot
()
if
self
.
is_master
:
self
.
log
.
append
((
self
.
step
,
{}))
# Log time
self
.
log
[
-
1
][
1
][
'time'
]
=
{
'step'
:
time_end
-
time_start
,
'elapsed'
:
time_elapsed
,
}
# Log losses
if
step_log
is
not
None
:
self
.
log
[
-
1
][
1
].
update
(
step_log
)
# Log scale
if
self
.
mix_precision_dtype
==
torch
.
float16
:
if
self
.
mix_precision_mode
==
'amp'
:
self
.
log
[
-
1
][
1
][
'scale'
]
=
self
.
scaler
.
get_scale
()
elif
self
.
mix_precision_mode
==
'inflat_all'
:
self
.
log
[
-
1
][
1
][
'log_scale'
]
=
self
.
log_scale
# Save log
if
self
.
step
%
self
.
i_log
==
0
:
self
.
save_logs
()
# Save checkpoint
if
self
.
step
%
self
.
i_save
==
0
:
self
.
save
()
# Check abort
self
.
check_abort
()
self
.
snapshot
(
suffix
=
'final'
,
batch_size
=
self
.
snapshot_batch_size
)
if
self
.
world_size
>
1
:
dist
.
barrier
()
if
self
.
is_master
:
self
.
writer
.
close
()
print
(
'Training finished.'
)
def
profile
(
self
,
wait
=
2
,
warmup
=
3
,
active
=
5
):
"""
Profile the training loop.
"""
with
torch
.
profiler
.
profile
(
schedule
=
torch
.
profiler
.
schedule
(
wait
=
wait
,
warmup
=
warmup
,
active
=
active
,
repeat
=
1
),
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
os
.
path
.
join
(
self
.
output_dir
,
'profile'
)),
profile_memory
=
True
,
with_stack
=
True
,
)
as
prof
:
for
_
in
range
(
wait
+
warmup
+
active
):
self
.
run_step
()
prof
.
step
()
TRELLIS.2_DCU/trellis2/trainers/flow_matching/flow_matching.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
copy
import
torch
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
import
numpy
as
np
from
easydict
import
EasyDict
as
edict
from
..basic
import
BasicTrainer
from
...pipelines
import
samplers
from
...utils.general_utils
import
dict_reduce
from
.mixins.classifier_free_guidance
import
ClassifierFreeGuidanceMixin
from
.mixins.text_conditioned
import
TextConditionedMixin
from
.mixins.image_conditioned
import
ImageConditionedMixin
class
FlowMatchingTrainer
(
BasicTrainer
):
"""
Trainer for diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def
__init__
(
self
,
*
args
,
t_schedule
:
dict
=
{
'name'
:
'logitNormal'
,
'args'
:
{
'mean'
:
0.0
,
'std'
:
1.0
,
}
},
sigma_min
:
float
=
1e-5
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
t_schedule
=
t_schedule
self
.
sigma_min
=
sigma_min
def
diffuse
(
self
,
x_0
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
noise
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
t: The [N] tensor of diffusion steps [0-1].
noise: If specified, use this noise instead of generating new noise.
Returns:
x_t, the noisy version of x_0 under timestep t.
"""
if
noise
is
None
:
noise
=
torch
.
randn_like
(
x_0
)
assert
noise
.
shape
==
x_0
.
shape
,
"noise must have same shape as x_0"
t
=
t
.
view
(
-
1
,
*
[
1
for
_
in
range
(
len
(
x_0
.
shape
)
-
1
)])
x_t
=
(
1
-
t
)
*
x_0
+
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
*
noise
return
x_t
def
reverse_diffuse
(
self
,
x_t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Get original image from noisy version under timestep t.
"""
assert
noise
.
shape
==
x_t
.
shape
,
"noise must have same shape as x_t"
t
=
t
.
view
(
-
1
,
*
[
1
for
_
in
range
(
len
(
x_t
.
shape
)
-
1
)])
x_0
=
(
x_t
-
(
self
.
sigma_min
+
(
1
-
self
.
sigma_min
)
*
t
)
*
noise
)
/
(
1
-
t
)
return
x_0
def
get_v
(
self
,
x_0
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute the velocity of the diffusion process at time t.
"""
return
(
1
-
self
.
sigma_min
)
*
noise
-
x_0
def
get_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data.
"""
return
cond
def
get_inference_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data for inference.
"""
return
{
'cond'
:
cond
,
**
kwargs
}
def
get_sampler
(
self
,
**
kwargs
)
->
samplers
.
FlowEulerSampler
:
"""
Get the sampler for the diffusion process.
"""
return
samplers
.
FlowEulerSampler
(
self
.
sigma_min
)
def
vis_cond
(
self
,
**
kwargs
):
"""
Visualize the conditioning data.
"""
return
{}
def
sample_t
(
self
,
batch_size
:
int
)
->
torch
.
Tensor
:
"""
Sample timesteps.
"""
if
self
.
t_schedule
[
'name'
]
==
'uniform'
:
t
=
torch
.
rand
(
batch_size
)
elif
self
.
t_schedule
[
'name'
]
==
'logitNormal'
:
mean
=
self
.
t_schedule
[
'args'
][
'mean'
]
std
=
self
.
t_schedule
[
'args'
][
'std'
]
t
=
torch
.
sigmoid
(
torch
.
randn
(
batch_size
)
*
std
+
mean
)
else
:
raise
ValueError
(
f
"Unknown t_schedule:
{
self
.
t_schedule
[
'name'
]
}
"
)
return
t
def
training_losses
(
self
,
x_0
:
torch
.
Tensor
,
cond
=
None
,
**
kwargs
)
->
Tuple
[
Dict
,
Dict
]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
may also contain other keys for different terms.
"""
noise
=
torch
.
randn_like
(
x_0
)
t
=
self
.
sample_t
(
x_0
.
shape
[
0
]).
to
(
x_0
.
device
).
float
()
x_t
=
self
.
diffuse
(
x_0
,
t
,
noise
=
noise
)
cond
=
self
.
get_cond
(
cond
,
**
kwargs
)
pred
=
self
.
training_models
[
'denoiser'
](
x_t
,
t
*
1000
,
cond
,
**
kwargs
)
assert
pred
.
shape
==
noise
.
shape
==
x_0
.
shape
target
=
self
.
get_v
(
x_0
,
noise
,
t
)
terms
=
edict
()
terms
[
"mse"
]
=
F
.
mse_loss
(
pred
,
target
)
terms
[
"loss"
]
=
terms
[
"mse"
]
# log loss with time bins
mse_per_instance
=
np
.
array
([
F
.
mse_loss
(
pred
[
i
],
target
[
i
]).
item
()
for
i
in
range
(
x_0
.
shape
[
0
])
])
time_bin
=
np
.
digitize
(
t
.
cpu
().
numpy
(),
np
.
linspace
(
0
,
1
,
11
))
-
1
for
i
in
range
(
10
):
if
(
time_bin
==
i
).
sum
()
!=
0
:
terms
[
f
"bin_
{
i
}
"
]
=
{
"mse"
:
mse_per_instance
[
time_bin
==
i
].
mean
()}
return
terms
,
{}
@
torch
.
no_grad
()
def
run_snapshot
(
self
,
num_samples
:
int
,
batch_size
:
int
,
verbose
:
bool
=
False
,
)
->
Dict
:
dataloader
=
DataLoader
(
copy
.
deepcopy
(
self
.
dataset
),
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
0
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
# inference
sampler
=
self
.
get_sampler
()
sample_gt
=
[]
sample
=
[]
cond_vis
=
[]
for
i
in
range
(
0
,
num_samples
,
batch_size
):
batch
=
min
(
batch_size
,
num_samples
-
i
)
data
=
next
(
iter
(
dataloader
))
data
=
{
k
:
v
[:
batch
].
cuda
()
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
[:
batch
]
for
k
,
v
in
data
.
items
()}
noise
=
torch
.
randn_like
(
data
[
'x_0'
])
sample_gt
.
append
(
data
[
'x_0'
])
cond_vis
.
append
(
self
.
vis_cond
(
**
data
))
del
data
[
'x_0'
]
args
=
self
.
get_inference_cond
(
**
data
)
res
=
sampler
.
sample
(
self
.
models
[
'denoiser'
],
noise
=
noise
,
**
args
,
steps
=
50
,
guidance_strength
=
3.0
,
verbose
=
verbose
,
)
sample
.
append
(
res
.
samples
)
sample_gt
=
torch
.
cat
(
sample_gt
,
dim
=
0
)
sample
=
torch
.
cat
(
sample
,
dim
=
0
)
sample_dict
=
{
'sample_gt'
:
{
'value'
:
sample_gt
,
'type'
:
'sample'
},
'sample'
:
{
'value'
:
sample
,
'type'
:
'sample'
},
}
sample_dict
.
update
(
dict_reduce
(
cond_vis
,
None
,
{
'value'
:
lambda
x
:
torch
.
cat
(
x
,
dim
=
0
),
'type'
:
lambda
x
:
x
[
0
],
}))
return
sample_dict
class
FlowMatchingCFGTrainer
(
ClassifierFreeGuidanceMixin
,
FlowMatchingTrainer
):
"""
Trainer for diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class
TextConditionedFlowMatchingCFGTrainer
(
TextConditionedMixin
,
FlowMatchingCFGTrainer
):
"""
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class
ImageConditionedFlowMatchingCFGTrainer
(
ImageConditionedMixin
,
FlowMatchingCFGTrainer
):
"""
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py
0 → 100644
View file @
f05e915f
import
torch
import
numpy
as
np
from
....utils.general_utils
import
dict_foreach
from
....pipelines
import
samplers
class
ClassifierFreeGuidanceMixin
:
def
__init__
(
self
,
*
args
,
p_uncond
:
float
=
0.1
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
p_uncond
=
p_uncond
def
get_cond
(
self
,
cond
,
neg_cond
=
None
,
**
kwargs
):
"""
Get the conditioning data.
"""
assert
neg_cond
is
not
None
,
"neg_cond must be provided for classifier-free guidance"
if
self
.
p_uncond
>
0
:
# randomly drop the class label
def
get_batch_size
(
cond
):
if
isinstance
(
cond
,
torch
.
Tensor
):
return
cond
.
shape
[
0
]
elif
isinstance
(
cond
,
list
):
return
len
(
cond
)
else
:
raise
ValueError
(
f
"Unsupported type of cond:
{
type
(
cond
)
}
"
)
ref_cond
=
cond
if
not
isinstance
(
cond
,
dict
)
else
cond
[
list
(
cond
.
keys
())[
0
]]
B
=
get_batch_size
(
ref_cond
)
def
select
(
cond
,
neg_cond
,
mask
):
if
isinstance
(
cond
,
torch
.
Tensor
):
mask
=
torch
.
tensor
(
mask
,
device
=
cond
.
device
).
reshape
(
-
1
,
*
[
1
]
*
(
cond
.
ndim
-
1
))
return
torch
.
where
(
mask
,
neg_cond
,
cond
)
elif
isinstance
(
cond
,
list
):
return
[
nc
if
m
else
c
for
c
,
nc
,
m
in
zip
(
cond
,
neg_cond
,
mask
)]
else
:
raise
ValueError
(
f
"Unsupported type of cond:
{
type
(
cond
)
}
"
)
mask
=
list
(
np
.
random
.
rand
(
B
)
<
self
.
p_uncond
)
if
not
isinstance
(
cond
,
dict
):
cond
=
select
(
cond
,
neg_cond
,
mask
)
else
:
cond
=
dict_foreach
([
cond
,
neg_cond
],
lambda
x
:
select
(
x
[
0
],
x
[
1
],
mask
))
return
cond
def
get_inference_cond
(
self
,
cond
,
neg_cond
=
None
,
**
kwargs
):
"""
Get the conditioning data for inference.
"""
assert
neg_cond
is
not
None
,
"neg_cond must be provided for classifier-free guidance"
return
{
'cond'
:
cond
,
'neg_cond'
:
neg_cond
,
**
kwargs
}
def
get_sampler
(
self
,
**
kwargs
)
->
samplers
.
FlowEulerCfgSampler
:
"""
Get the sampler for the diffusion process.
"""
return
samplers
.
FlowEulerCfgSampler
(
self
.
sigma_min
)
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/image_conditioned.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn.functional
as
F
from
torchvision
import
transforms
from
transformers
import
DINOv3ViTModel
import
numpy
as
np
from
PIL
import
Image
from
....utils
import
dist_utils
class
DinoV2FeatureExtractor
:
"""
Feature extractor for DINOv2 models.
"""
def
__init__
(
self
,
model_name
:
str
):
self
.
model_name
=
model_name
self
.
model
=
torch
.
hub
.
load
(
'facebookresearch/dinov2'
,
model_name
,
pretrained
=
True
)
self
.
model
.
eval
()
self
.
transform
=
transforms
.
Compose
([
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
])
def
to
(
self
,
device
):
self
.
model
.
to
(
device
)
def
cuda
(
self
):
self
.
model
.
cuda
()
def
cpu
(
self
):
self
.
model
.
cpu
()
@
torch
.
no_grad
()
def
__call__
(
self
,
image
:
Union
[
torch
.
Tensor
,
List
[
Image
.
Image
]])
->
torch
.
Tensor
:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if
isinstance
(
image
,
torch
.
Tensor
):
assert
image
.
ndim
==
4
,
"Image tensor should be batched (B, C, H, W)"
elif
isinstance
(
image
,
list
):
assert
all
(
isinstance
(
i
,
Image
.
Image
)
for
i
in
image
),
"Image list should be list of PIL images"
image
=
[
i
.
resize
((
518
,
518
),
Image
.
LANCZOS
)
for
i
in
image
]
image
=
[
np
.
array
(
i
.
convert
(
'RGB'
)).
astype
(
np
.
float32
)
/
255
for
i
in
image
]
image
=
[
torch
.
from_numpy
(
i
).
permute
(
2
,
0
,
1
).
float
()
for
i
in
image
]
image
=
torch
.
stack
(
image
).
cuda
()
else
:
raise
ValueError
(
f
"Unsupported type of image:
{
type
(
image
)
}
"
)
image
=
self
.
transform
(
image
).
cuda
()
features
=
self
.
model
(
image
,
is_training
=
True
)[
'x_prenorm'
]
patchtokens
=
F
.
layer_norm
(
features
,
features
.
shape
[
-
1
:])
return
patchtokens
class
DinoV3FeatureExtractor
:
"""
Feature extractor for DINOv3 models.
"""
def
__init__
(
self
,
model_name
:
str
,
image_size
=
512
):
self
.
model_name
=
model_name
self
.
model
=
DINOv3ViTModel
.
from_pretrained
(
model_name
)
self
.
model
.
eval
()
self
.
image_size
=
image_size
self
.
transform
=
transforms
.
Compose
([
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
])
def
to
(
self
,
device
):
self
.
model
.
to
(
device
)
def
cuda
(
self
):
self
.
model
.
cuda
()
def
cpu
(
self
):
self
.
model
.
cpu
()
def
extract_features
(
self
,
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
image
=
image
.
to
(
self
.
model
.
embeddings
.
patch_embeddings
.
weight
.
dtype
)
hidden_states
=
self
.
model
.
embeddings
(
image
,
bool_masked_pos
=
None
)
position_embeddings
=
self
.
model
.
rope_embeddings
(
image
)
for
i
,
layer_module
in
enumerate
(
self
.
model
.
layer
):
hidden_states
=
layer_module
(
hidden_states
,
position_embeddings
=
position_embeddings
,
)
return
F
.
layer_norm
(
hidden_states
,
hidden_states
.
shape
[
-
1
:])
@
torch
.
no_grad
()
def
__call__
(
self
,
image
:
Union
[
torch
.
Tensor
,
List
[
Image
.
Image
]])
->
torch
.
Tensor
:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if
isinstance
(
image
,
torch
.
Tensor
):
assert
image
.
ndim
==
4
,
"Image tensor should be batched (B, C, H, W)"
elif
isinstance
(
image
,
list
):
assert
all
(
isinstance
(
i
,
Image
.
Image
)
for
i
in
image
),
"Image list should be list of PIL images"
image
=
[
i
.
resize
((
self
.
image_size
,
self
.
image_size
),
Image
.
LANCZOS
)
for
i
in
image
]
image
=
[
np
.
array
(
i
.
convert
(
'RGB'
)).
astype
(
np
.
float32
)
/
255
for
i
in
image
]
image
=
[
torch
.
from_numpy
(
i
).
permute
(
2
,
0
,
1
).
float
()
for
i
in
image
]
image
=
torch
.
stack
(
image
).
cuda
()
else
:
raise
ValueError
(
f
"Unsupported type of image:
{
type
(
image
)
}
"
)
image
=
self
.
transform
(
image
).
cuda
()
features
=
self
.
extract_features
(
image
)
return
features
class
ImageConditionedMixin
:
"""
Mixin for image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def
__init__
(
self
,
*
args
,
image_cond_model
:
dict
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
image_cond_model_config
=
image_cond_model
self
.
image_cond_model
=
None
# the model is init lazily
def
_init_image_cond_model
(
self
):
"""
Initialize the image conditioning model.
"""
with
dist_utils
.
local_master_first
():
self
.
image_cond_model
=
globals
()[
self
.
image_cond_model_config
[
'name'
]](
**
self
.
image_cond_model_config
.
get
(
'args'
,
{}))
self
.
image_cond_model
.
cuda
()
@
torch
.
no_grad
()
def
encode_image
(
self
,
image
:
Union
[
torch
.
Tensor
,
List
[
Image
.
Image
]])
->
torch
.
Tensor
:
"""
Encode the image.
"""
if
self
.
image_cond_model
is
None
:
self
.
_init_image_cond_model
()
features
=
self
.
image_cond_model
(
image
)
return
features
def
get_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data.
"""
cond
=
self
.
encode_image
(
cond
)
kwargs
[
'neg_cond'
]
=
torch
.
zeros_like
(
cond
)
cond
=
super
().
get_cond
(
cond
,
**
kwargs
)
return
cond
def
get_inference_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data for inference.
"""
cond
=
self
.
encode_image
(
cond
)
kwargs
[
'neg_cond'
]
=
torch
.
zeros_like
(
cond
)
cond
=
super
().
get_inference_cond
(
cond
,
**
kwargs
)
return
cond
def
vis_cond
(
self
,
cond
,
**
kwargs
):
"""
Visualize the conditioning data.
"""
return
{
'image'
:
{
'value'
:
cond
,
'type'
:
'image'
}}
class
MultiImageConditionedMixin
:
"""
Mixin for multiple-image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def
__init__
(
self
,
*
args
,
image_cond_model
:
dict
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
image_cond_model_config
=
image_cond_model
self
.
image_cond_model
=
None
# the model is init lazily
def
_init_image_cond_model
(
self
):
"""
Initialize the image conditioning model.
"""
with
dist_utils
.
local_master_first
():
self
.
image_cond_model
=
globals
()[
self
.
image_cond_model_config
[
'name'
]](
**
self
.
image_cond_model_config
.
get
(
'args'
,
{}))
@
torch
.
no_grad
()
def
encode_images
(
self
,
images
:
Union
[
List
[
torch
.
Tensor
],
List
[
List
[
Image
.
Image
]]])
->
List
[
torch
.
Tensor
]:
"""
Encode the image.
"""
if
self
.
image_cond_model
is
None
:
self
.
_init_image_cond_model
()
seqlen
=
[
len
(
i
)
for
i
in
images
]
images
=
torch
.
cat
(
images
,
dim
=
0
)
if
isinstance
(
images
[
0
],
torch
.
Tensor
)
else
sum
(
images
,
[])
features
=
self
.
image_cond_model
(
images
)
features
=
torch
.
split
(
features
,
seqlen
)
features
=
[
feature
.
reshape
(
-
1
,
feature
.
shape
[
-
1
])
for
feature
in
features
]
return
features
def
get_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data.
"""
cond
=
self
.
encode_images
(
cond
)
kwargs
[
'neg_cond'
]
=
[
torch
.
zeros_like
(
cond
[
0
][:
1
,
:])
for
_
in
range
(
len
(
cond
))
]
cond
=
super
().
get_cond
(
cond
,
**
kwargs
)
return
cond
def
get_inference_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data for inference.
"""
cond
=
self
.
encode_images
(
cond
)
kwargs
[
'neg_cond'
]
=
[
torch
.
zeros_like
(
cond
[
0
][:
1
,
:])
for
_
in
range
(
len
(
cond
))
]
cond
=
super
().
get_inference_cond
(
cond
,
**
kwargs
)
return
cond
def
vis_cond
(
self
,
cond
,
**
kwargs
):
"""
Visualize the conditioning data.
"""
H
,
W
=
cond
[
0
].
shape
[
-
2
:]
vis
=
[]
for
images
in
cond
:
canvas
=
torch
.
zeros
(
3
,
H
*
2
,
W
*
2
,
device
=
images
.
device
,
dtype
=
images
.
dtype
)
for
i
,
image
in
enumerate
(
images
):
if
i
==
4
:
break
kh
=
i
//
2
kw
=
i
%
2
canvas
[:,
kh
*
H
:(
kh
+
1
)
*
H
,
kw
*
W
:(
kw
+
1
)
*
W
]
=
image
vis
.
append
(
canvas
)
vis
=
torch
.
stack
(
vis
)
return
{
'image'
:
{
'value'
:
vis
,
'type'
:
'image'
}}
TRELLIS.2_DCU/trellis2/trainers/flow_matching/mixins/text_conditioned.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
os
os
.
environ
[
'TOKENIZERS_PARALLELISM'
]
=
'true'
import
torch
from
transformers
import
AutoTokenizer
,
CLIPTextModel
from
....utils
import
dist_utils
class
TextConditionedMixin
:
"""
Mixin for text-conditioned models.
Args:
text_cond_model: The text conditioning model.
"""
def
__init__
(
self
,
*
args
,
text_cond_model
:
str
=
'openai/clip-vit-large-patch14'
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
text_cond_model_name
=
text_cond_model
self
.
text_cond_model
=
None
# the model is init lazily
def
_init_text_cond_model
(
self
):
"""
Initialize the text conditioning model.
"""
# load model
with
dist_utils
.
local_master_first
():
model
=
CLIPTextModel
.
from_pretrained
(
self
.
text_cond_model_name
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
text_cond_model_name
)
model
.
eval
()
model
=
model
.
cuda
()
self
.
text_cond_model
=
{
'model'
:
model
,
'tokenizer'
:
tokenizer
,
}
self
.
text_cond_model
[
'null_cond'
]
=
self
.
encode_text
([
''
])
@
torch
.
no_grad
()
def
encode_text
(
self
,
text
:
List
[
str
])
->
torch
.
Tensor
:
"""
Encode the text.
"""
assert
isinstance
(
text
,
list
)
and
isinstance
(
text
[
0
],
str
),
"TextConditionedMixin only supports list of strings as cond"
if
self
.
text_cond_model
is
None
:
self
.
_init_text_cond_model
()
encoding
=
self
.
text_cond_model
[
'tokenizer'
](
text
,
max_length
=
77
,
padding
=
'max_length'
,
truncation
=
True
,
return_tensors
=
'pt'
)
tokens
=
encoding
[
'input_ids'
].
cuda
()
embeddings
=
self
.
text_cond_model
[
'model'
](
input_ids
=
tokens
).
last_hidden_state
return
embeddings
def
get_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data.
"""
cond
=
self
.
encode_text
(
cond
)
kwargs
[
'neg_cond'
]
=
self
.
text_cond_model
[
'null_cond'
].
repeat
(
cond
.
shape
[
0
],
1
,
1
)
cond
=
super
().
get_cond
(
cond
,
**
kwargs
)
return
cond
def
get_inference_cond
(
self
,
cond
,
**
kwargs
):
"""
Get the conditioning data for inference.
"""
cond
=
self
.
encode_text
(
cond
)
kwargs
[
'neg_cond'
]
=
self
.
text_cond_model
[
'null_cond'
].
repeat
(
cond
.
shape
[
0
],
1
,
1
)
cond
=
super
().
get_inference_cond
(
cond
,
**
kwargs
)
return
cond
TRELLIS.2_DCU/trellis2/trainers/flow_matching/sparse_flow_matching.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
os
import
copy
import
functools
import
torch
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
import
numpy
as
np
from
easydict
import
EasyDict
as
edict
from
...modules
import
sparse
as
sp
from
...utils.general_utils
import
dict_reduce
from
...utils.data_utils
import
recursive_to_device
,
cycle
,
BalancedResumableSampler
from
.flow_matching
import
FlowMatchingTrainer
from
.mixins.classifier_free_guidance
import
ClassifierFreeGuidanceMixin
from
.mixins.text_conditioned
import
TextConditionedMixin
from
.mixins.image_conditioned
import
ImageConditionedMixin
,
MultiImageConditionedMixin
class
SparseFlowMatchingTrainer
(
FlowMatchingTrainer
):
"""
Trainer for sparse diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def
prepare_dataloader
(
self
,
**
kwargs
):
"""
Prepare dataloader.
"""
self
.
data_sampler
=
BalancedResumableSampler
(
self
.
dataset
,
shuffle
=
True
,
batch_size
=
self
.
batch_size_per_gpu
,
)
self
.
dataloader
=
DataLoader
(
self
.
dataset
,
batch_size
=
self
.
batch_size_per_gpu
,
num_workers
=
int
(
np
.
ceil
(
os
.
cpu_count
()
/
torch
.
cuda
.
device_count
())),
pin_memory
=
True
,
drop_last
=
True
,
persistent_workers
=
True
,
collate_fn
=
functools
.
partial
(
self
.
dataset
.
collate_fn
,
split_size
=
self
.
batch_split
),
sampler
=
self
.
data_sampler
,
)
self
.
data_iterator
=
cycle
(
self
.
dataloader
)
def
training_losses
(
self
,
x_0
:
sp
.
SparseTensor
,
cond
=
None
,
**
kwargs
)
->
Tuple
[
Dict
,
Dict
]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x ... x C] sparse tensor of the inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
may also contain other keys for different terms.
"""
noise
=
x_0
.
replace
(
torch
.
randn_like
(
x_0
.
feats
))
t
=
self
.
sample_t
(
x_0
.
shape
[
0
]).
to
(
x_0
.
device
).
float
()
x_t
=
self
.
diffuse
(
x_0
,
t
,
noise
=
noise
)
cond
=
self
.
get_cond
(
cond
,
**
kwargs
)
pred
=
self
.
training_models
[
'denoiser'
](
x_t
,
t
*
1000
,
cond
,
**
kwargs
)
assert
pred
.
shape
==
noise
.
shape
==
x_0
.
shape
target
=
self
.
get_v
(
x_0
,
noise
,
t
)
terms
=
edict
()
terms
[
"mse"
]
=
F
.
mse_loss
(
pred
.
feats
,
target
.
feats
)
terms
[
"loss"
]
=
terms
[
"mse"
]
# log loss with time bins
mse_per_instance
=
np
.
array
([
F
.
mse_loss
(
pred
.
feats
[
x_0
.
layout
[
i
]],
target
.
feats
[
x_0
.
layout
[
i
]]).
item
()
for
i
in
range
(
x_0
.
shape
[
0
])
])
time_bin
=
np
.
digitize
(
t
.
cpu
().
numpy
(),
np
.
linspace
(
0
,
1
,
11
))
-
1
for
i
in
range
(
10
):
if
(
time_bin
==
i
).
sum
()
!=
0
:
terms
[
f
"bin_
{
i
}
"
]
=
{
"mse"
:
mse_per_instance
[
time_bin
==
i
].
mean
()}
return
terms
,
{}
@
torch
.
no_grad
()
def
run_snapshot
(
self
,
num_samples
:
int
,
batch_size
:
int
,
verbose
:
bool
=
False
,
)
->
Dict
:
dataloader
=
DataLoader
(
copy
.
deepcopy
(
self
.
dataset
),
batch_size
=
num_samples
,
shuffle
=
True
,
num_workers
=
0
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
data
=
next
(
iter
(
dataloader
))
# inference
sampler
=
self
.
get_sampler
()
sample
=
[]
cond_vis
=
[]
for
i
in
range
(
0
,
num_samples
,
batch_size
):
batch_data
=
{
k
:
v
[
i
:
i
+
batch_size
]
for
k
,
v
in
data
.
items
()}
batch_data
=
recursive_to_device
(
batch_data
,
'cuda'
)
noise
=
batch_data
[
'x_0'
].
replace
(
torch
.
randn_like
(
batch_data
[
'x_0'
].
feats
))
cond_vis
.
append
(
self
.
vis_cond
(
**
batch_data
))
del
batch_data
[
'x_0'
]
args
=
self
.
get_inference_cond
(
**
batch_data
)
res
=
sampler
.
sample
(
self
.
models
[
'denoiser'
],
noise
=
noise
,
**
args
,
steps
=
12
,
guidance_strength
=
3.0
,
verbose
=
verbose
,
)
sample
.
append
(
res
.
samples
)
sample
=
sp
.
sparse_cat
(
sample
)
sample_gt
=
{
k
:
v
for
k
,
v
in
data
.
items
()}
sample
=
{
k
:
v
if
k
!=
'x_0'
else
sample
for
k
,
v
in
data
.
items
()}
sample_dict
=
{
'sample_gt'
:
{
'value'
:
sample_gt
,
'type'
:
'sample'
},
'sample'
:
{
'value'
:
sample
,
'type'
:
'sample'
},
}
sample_dict
.
update
(
dict_reduce
(
cond_vis
,
None
,
{
'value'
:
lambda
x
:
torch
.
cat
(
x
,
dim
=
0
),
'type'
:
lambda
x
:
x
[
0
],
}))
return
sample_dict
class
SparseFlowMatchingCFGTrainer
(
ClassifierFreeGuidanceMixin
,
SparseFlowMatchingTrainer
):
"""
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class
TextConditionedSparseFlowMatchingCFGTrainer
(
TextConditionedMixin
,
SparseFlowMatchingCFGTrainer
):
"""
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class
ImageConditionedSparseFlowMatchingCFGTrainer
(
ImageConditionedMixin
,
SparseFlowMatchingCFGTrainer
):
"""
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
class
MultiImageConditionedSparseFlowMatchingCFGTrainer
(
MultiImageConditionedMixin
,
SparseFlowMatchingCFGTrainer
):
"""
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
TRELLIS.2_DCU/trellis2/trainers/utils.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
# FP16 utils
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
def
str_to_dtype
(
dtype_str
:
str
):
return
{
'f16'
:
torch
.
float16
,
'fp16'
:
torch
.
float16
,
'float16'
:
torch
.
float16
,
'bf16'
:
torch
.
bfloat16
,
'bfloat16'
:
torch
.
bfloat16
,
'f32'
:
torch
.
float32
,
'fp32'
:
torch
.
float32
,
'float32'
:
torch
.
float32
,
}[
dtype_str
]
def
make_master_params
(
model_params
):
"""
Copy model parameters into a inflated tensor of full-precision parameters.
"""
master_params
=
_flatten_dense_tensors
(
[
param
.
detach
().
float
()
for
param
in
model_params
]
)
master_params
=
nn
.
Parameter
(
master_params
)
master_params
.
requires_grad
=
True
return
[
master_params
]
def
unflatten_master_params
(
model_params
,
master_params
):
"""
Unflatten the master parameters to look like model_params.
"""
return
_unflatten_dense_tensors
(
master_params
[
0
].
detach
(),
model_params
)
def
model_params_to_master_params
(
model_params
,
master_params
):
"""
Copy the model parameter data into the master parameters.
"""
master_params
[
0
].
detach
().
copy_
(
_flatten_dense_tensors
([
param
.
detach
().
float
()
for
param
in
model_params
])
)
def
master_params_to_model_params
(
model_params
,
master_params
):
"""
Copy the master parameter data back into the model parameters.
"""
for
param
,
master_param
in
zip
(
model_params
,
_unflatten_dense_tensors
(
master_params
[
0
].
detach
(),
model_params
)
):
param
.
detach
().
copy_
(
master_param
)
def
model_grads_to_master_grads
(
model_params
,
master_params
):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params
[
0
].
grad
=
_flatten_dense_tensors
(
[
param
.
grad
.
data
.
detach
().
float
()
for
param
in
model_params
]
)
def
zero_grad
(
model_params
):
for
param
in
model_params
:
if
param
.
grad
is
not
None
:
if
param
.
grad
.
grad_fn
is
not
None
:
param
.
grad
.
detach_
()
else
:
param
.
grad
.
requires_grad_
(
False
)
param
.
grad
.
zero_
()
# LR Schedulers
from
torch.optim.lr_scheduler
import
LambdaLR
class
LinearWarmupLRScheduler
(
LambdaLR
):
def
__init__
(
self
,
optimizer
,
warmup_steps
,
last_epoch
=-
1
):
self
.
warmup_steps
=
warmup_steps
super
(
LinearWarmupLRScheduler
,
self
).
__init__
(
optimizer
,
self
.
lr_lambda
,
last_epoch
=
last_epoch
)
def
lr_lambda
(
self
,
current_step
):
if
current_step
<
self
.
warmup_steps
:
return
float
(
current_step
+
1
)
/
self
.
warmup_steps
return
1.0
\ No newline at end of file
TRELLIS.2_DCU/trellis2/trainers/vae/pbr_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
os
import
copy
import
functools
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
import
utils3d
from
easydict
import
EasyDict
as
edict
from
..basic
import
BasicTrainer
from
...modules
import
sparse
as
sp
from
...renderers
import
MeshRenderer
from
...representations
import
Mesh
,
MeshWithPbrMaterial
,
MeshWithVoxel
from
...utils.data_utils
import
recursive_to_device
,
cycle
,
BalancedResumableSampler
from
...utils.loss_utils
import
l1_loss
,
l2_loss
,
ssim
,
lpips
class
PbrVaeTrainer
(
BasicTrainer
):
"""
Trainer for PBR attributes VAE
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
loss_type (str): Loss type.
lambda_kl (float): KL loss weight.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
"""
def
__init__
(
self
,
*
args
,
loss_type
:
str
=
'l1'
,
lambda_kl
:
float
=
1e-6
,
lambda_ssim
:
float
=
0.2
,
lambda_lpips
:
float
=
0.2
,
lambda_render
:
float
=
1.0
,
render_resolution
:
float
=
1024
,
camera_randomization_config
:
dict
=
{
'radius_range'
:
[
2
,
100
],
},
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
loss_type
=
loss_type
self
.
lambda_kl
=
lambda_kl
self
.
lambda_ssim
=
lambda_ssim
self
.
lambda_lpips
=
lambda_lpips
self
.
lambda_render
=
lambda_render
self
.
camera_randomization_config
=
camera_randomization_config
self
.
renderer
=
MeshRenderer
({
'near'
:
1
,
'far'
:
3
,
'resolution'
:
render_resolution
},
device
=
self
.
device
)
def
prepare_dataloader
(
self
,
**
kwargs
):
"""
Prepare dataloader.
"""
self
.
data_sampler
=
BalancedResumableSampler
(
self
.
dataset
,
shuffle
=
True
,
batch_size
=
self
.
batch_size_per_gpu
,
)
self
.
dataloader
=
DataLoader
(
self
.
dataset
,
batch_size
=
self
.
batch_size_per_gpu
,
num_workers
=
int
(
np
.
ceil
(
os
.
cpu_count
()
/
torch
.
cuda
.
device_count
())),
pin_memory
=
True
,
drop_last
=
True
,
persistent_workers
=
True
,
collate_fn
=
functools
.
partial
(
self
.
dataset
.
collate_fn
,
split_size
=
self
.
batch_split
),
sampler
=
self
.
data_sampler
,
)
self
.
data_iterator
=
cycle
(
self
.
dataloader
)
def
_randomize_camera
(
self
,
num_samples
:
int
):
# sample radius and fov
r_min
,
r_max
=
self
.
camera_randomization_config
[
'radius_range'
]
k_min
=
1
/
r_max
**
2
k_max
=
1
/
r_min
**
2
ks
=
torch
.
rand
(
num_samples
,
device
=
self
.
device
)
*
(
k_max
-
k_min
)
+
k_min
radius
=
1
/
torch
.
sqrt
(
ks
)
fov
=
2
*
torch
.
arcsin
(
0.5
/
radius
)
origin
=
radius
.
unsqueeze
(
-
1
)
*
F
.
normalize
(
torch
.
randn
(
num_samples
,
3
,
device
=
self
.
device
),
dim
=-
1
)
# build camera
extrinsics
=
utils3d
.
torch
.
extrinsics_look_at
(
origin
,
torch
.
zeros_like
(
origin
),
torch
.
tensor
([
0
,
0
,
1
],
dtype
=
torch
.
float32
,
device
=
self
.
device
))
intrinsics
=
utils3d
.
torch
.
intrinsics_from_fov_xy
(
fov
,
fov
)
near
=
[
np
.
random
.
uniform
(
r
-
1
,
r
)
for
r
in
radius
.
tolist
()]
return
{
'extrinsics'
:
extrinsics
,
'intrinsics'
:
intrinsics
,
'near'
:
near
,
}
def
_render_batch
(
self
,
reps
:
List
[
Mesh
],
extrinsics
:
torch
.
Tensor
,
intrinsics
:
torch
.
Tensor
,
near
:
List
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Render a batch of representations.
Args:
reps: The dictionary of lists of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
Returns:
a dict with
base_color : [N x 3 x H x W] tensor of base color.
metallic : [N x 1 x H x W] tensor of metallic.
roughness : [N x 1 x H x W] tensor of roughness.
alpha : [N x 1 x H x W] tensor of alpha.
"""
ret
=
{
k
:
[]
for
k
in
[
'base_color'
,
'metallic'
,
'roughness'
,
'alpha'
]}
for
i
,
rep
in
enumerate
(
reps
):
self
.
renderer
.
rendering_options
[
'near'
]
=
near
[
i
]
self
.
renderer
.
rendering_options
[
'far'
]
=
near
[
i
]
+
2
out_dict
=
self
.
renderer
.
render
(
rep
,
extrinsics
[
i
],
intrinsics
[
i
],
return_types
=
[
'attr'
])
for
k
in
out_dict
:
ret
[
k
].
append
(
out_dict
[
k
])
for
k
in
ret
:
ret
[
k
]
=
torch
.
stack
(
ret
[
k
])
return
ret
def
training_losses
(
self
,
x
:
sp
.
SparseTensor
,
mesh
:
List
[
MeshWithPbrMaterial
]
=
None
,
**
kwargs
)
->
Tuple
[
Dict
,
Dict
]:
"""
Compute training losses.
Args:
x (SparseTensor): Input sparse tensor for pbr materials.
mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z
,
mean
,
logvar
=
self
.
training_models
[
'encoder'
](
x
,
sample_posterior
=
True
,
return_raw
=
True
)
y
=
self
.
training_models
[
'decoder'
](
z
)
terms
=
edict
(
loss
=
0.0
)
# direct regression
if
self
.
loss_type
==
'l1'
:
terms
[
"l1"
]
=
l1_loss
(
x
.
feats
,
y
.
feats
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
terms
[
"l1"
]
elif
self
.
loss_type
==
'l2'
:
terms
[
"l2"
]
=
l2_loss
(
x
.
feats
,
y
.
feats
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
terms
[
"l2"
]
else
:
raise
ValueError
(
f
'Invalid loss type
{
self
.
loss_type
}
'
)
# rendering loss
if
self
.
lambda_render
!=
0.0
:
recon
=
[
MeshWithVoxel
(
m
.
vertices
,
m
.
faces
,
[
-
0.5
,
-
0.5
,
-
0.5
],
1
/
self
.
dataset
.
resolution
,
v
.
coords
[:,
1
:],
v
.
feats
*
0.5
+
0.5
,
torch
.
Size
([
*
v
.
shape
,
*
v
.
spatial_shape
]),
layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
)
for
m
,
v
in
zip
(
mesh
,
y
)]
cameras
=
self
.
_randomize_camera
(
len
(
mesh
))
gt_renders
=
self
.
_render_batch
(
mesh
,
**
cameras
)
pred_renders
=
self
.
_render_batch
(
recon
,
**
cameras
)
gt_base_color
=
gt_renders
[
'base_color'
]
pred_base_color
=
pred_renders
[
'base_color'
]
gt_mra
=
torch
.
cat
([
gt_renders
[
'metallic'
],
gt_renders
[
'roughness'
],
gt_renders
[
'alpha'
]],
dim
=
1
)
pred_mra
=
torch
.
cat
([
pred_renders
[
'metallic'
],
pred_renders
[
'roughness'
],
pred_renders
[
'alpha'
]],
dim
=
1
)
terms
[
'render/base_color/ssim'
]
=
1
-
ssim
(
pred_base_color
,
gt_base_color
)
terms
[
'render/base_color/lpips'
]
=
lpips
(
pred_base_color
,
gt_base_color
)
terms
[
'render/mra/ssim'
]
=
1
-
ssim
(
pred_mra
,
gt_mra
)
terms
[
'render/mra/lpips'
]
=
lpips
(
pred_mra
,
gt_mra
)
terms
[
'loss'
]
=
terms
[
'loss'
]
+
\
self
.
lambda_render
*
(
self
.
lambda_ssim
*
terms
[
'render/base_color/ssim'
]
+
self
.
lambda_lpips
*
terms
[
'render/base_color/lpips'
]
+
\
self
.
lambda_ssim
*
terms
[
'render/mra/ssim'
]
+
self
.
lambda_lpips
*
terms
[
'render/mra/lpips'
])
# KL regularization
terms
[
"kl"
]
=
0.5
*
torch
.
mean
(
mean
.
pow
(
2
)
+
logvar
.
exp
()
-
logvar
-
1
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lambda_kl
*
terms
[
"kl"
]
return
terms
,
{}
@
torch
.
no_grad
()
def
run_snapshot
(
self
,
num_samples
:
int
,
batch_size
:
int
,
verbose
:
bool
=
False
,
)
->
Dict
:
dataloader
=
DataLoader
(
copy
.
deepcopy
(
self
.
dataset
),
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
1
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
dataloader
.
dataset
.
with_mesh
=
True
# inference
gts
=
[]
recons
=
[]
self
.
models
[
'encoder'
].
eval
()
self
.
models
[
'decoder'
].
eval
()
for
i
in
range
(
0
,
num_samples
,
batch_size
):
batch
=
min
(
batch_size
,
num_samples
-
i
)
data
=
next
(
iter
(
dataloader
))
args
=
{
k
:
v
[:
batch
]
for
k
,
v
in
data
.
items
()}
args
=
recursive_to_device
(
args
,
self
.
device
)
z
=
self
.
models
[
'encoder'
](
args
[
'x'
])
y
=
self
.
models
[
'decoder'
](
z
)
gts
.
extend
(
args
[
'mesh'
])
recons
.
extend
([
MeshWithVoxel
(
m
.
vertices
,
m
.
faces
,
[
-
0.5
,
-
0.5
,
-
0.5
],
1
/
self
.
dataset
.
resolution
,
v
.
coords
[:,
1
:],
v
.
feats
*
0.5
+
0.5
,
torch
.
Size
([
*
v
.
shape
,
*
v
.
spatial_shape
]),
layout
=
{
'base_color'
:
slice
(
0
,
3
),
'metallic'
:
slice
(
3
,
4
),
'roughness'
:
slice
(
4
,
5
),
'alpha'
:
slice
(
5
,
6
),
}
)
for
m
,
v
in
zip
(
args
[
'mesh'
],
y
)])
self
.
models
[
'encoder'
].
train
()
self
.
models
[
'decoder'
].
train
()
cameras
=
self
.
_randomize_camera
(
num_samples
)
gt_renders
=
self
.
_render_batch
(
gts
,
**
cameras
)
pred_renders
=
self
.
_render_batch
(
recons
,
**
cameras
)
sample_dict
=
{
'gt_base_color'
:
{
'value'
:
gt_renders
[
'base_color'
]
*
2
-
1
,
'type'
:
'image'
},
'pred_base_color'
:
{
'value'
:
pred_renders
[
'base_color'
]
*
2
-
1
,
'type'
:
'image'
},
'gt_mra'
:
{
'value'
:
torch
.
cat
([
gt_renders
[
'metallic'
],
gt_renders
[
'roughness'
],
gt_renders
[
'alpha'
]],
dim
=
1
)
*
2
-
1
,
'type'
:
'image'
},
'pred_mra'
:
{
'value'
:
torch
.
cat
([
pred_renders
[
'metallic'
],
pred_renders
[
'roughness'
],
pred_renders
[
'alpha'
]],
dim
=
1
)
*
2
-
1
,
'type'
:
'image'
},
}
return
sample_dict
TRELLIS.2_DCU/trellis2/trainers/vae/shape_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
os
import
copy
import
functools
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
import
utils3d
from
easydict
import
EasyDict
as
edict
from
..basic
import
BasicTrainer
from
...modules
import
sparse
as
sp
from
...renderers
import
MeshRenderer
from
...representations
import
Mesh
from
...utils.data_utils
import
recursive_to_device
,
cycle
,
BalancedResumableSampler
from
...utils.loss_utils
import
l1_loss
,
ssim
,
lpips
class
ShapeVaeTrainer
(
BasicTrainer
):
"""
Trainer for Shape VAE
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
lambda_subdiv (float): Subdivision loss weight.
lambda_intersected (float): Intersected loss weight.
lambda_vertice (float): Vertice loss weight.
lambda_kl (float): KL loss weight.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
"""
def
__init__
(
self
,
*
args
,
lambda_subdiv
:
float
=
0.1
,
lambda_intersected
:
float
=
0.1
,
lambda_vertice
:
float
=
1e-2
,
lambda_mask
:
float
=
1
,
lambda_depth
:
float
=
10
,
lambda_normal
:
float
=
1
,
lambda_kl
:
float
=
1e-6
,
lambda_ssim
:
float
=
0.2
,
lambda_lpips
:
float
=
0.2
,
render_resolution
:
float
=
1024
,
camera_randomization_config
:
dict
=
{
'radius_range'
:
[
2
,
100
],
},
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
lambda_subdiv
=
lambda_subdiv
self
.
lambda_intersected
=
lambda_intersected
self
.
lambda_mask
=
lambda_mask
self
.
lambda_vertice
=
lambda_vertice
self
.
lambda_depth
=
lambda_depth
self
.
lambda_normal
=
lambda_normal
self
.
lambda_kl
=
lambda_kl
self
.
lambda_ssim
=
lambda_ssim
self
.
lambda_lpips
=
lambda_lpips
self
.
camera_randomization_config
=
camera_randomization_config
self
.
renderer
=
MeshRenderer
({
'near'
:
1
,
'far'
:
3
,
'resolution'
:
render_resolution
},
device
=
self
.
device
)
def
prepare_dataloader
(
self
,
**
kwargs
):
"""
Prepare dataloader.
"""
self
.
data_sampler
=
BalancedResumableSampler
(
self
.
dataset
,
shuffle
=
True
,
batch_size
=
self
.
batch_size_per_gpu
,
)
self
.
dataloader
=
DataLoader
(
self
.
dataset
,
batch_size
=
self
.
batch_size_per_gpu
,
num_workers
=
int
(
np
.
ceil
(
os
.
cpu_count
()
/
torch
.
cuda
.
device_count
())),
pin_memory
=
True
,
drop_last
=
True
,
persistent_workers
=
True
,
collate_fn
=
functools
.
partial
(
self
.
dataset
.
collate_fn
,
split_size
=
self
.
batch_split
),
sampler
=
self
.
data_sampler
,
)
self
.
data_iterator
=
cycle
(
self
.
dataloader
)
def
_randomize_camera
(
self
,
num_samples
:
int
):
# sample radius and fov
r_min
,
r_max
=
self
.
camera_randomization_config
[
'radius_range'
]
k_min
=
1
/
r_max
**
2
k_max
=
1
/
r_min
**
2
ks
=
torch
.
rand
(
num_samples
,
device
=
self
.
device
)
*
(
k_max
-
k_min
)
+
k_min
radius
=
1
/
torch
.
sqrt
(
ks
)
fov
=
2
*
torch
.
arcsin
(
0.5
/
radius
)
origin
=
radius
.
unsqueeze
(
-
1
)
*
F
.
normalize
(
torch
.
randn
(
num_samples
,
3
,
device
=
self
.
device
),
dim
=-
1
)
# build camera
extrinsics
=
utils3d
.
torch
.
extrinsics_look_at
(
origin
,
torch
.
zeros_like
(
origin
),
torch
.
tensor
([
0
,
0
,
1
],
dtype
=
torch
.
float32
,
device
=
self
.
device
))
intrinsics
=
utils3d
.
torch
.
intrinsics_from_fov_xy
(
fov
,
fov
)
near
=
[
np
.
random
.
uniform
(
r
-
1
,
r
)
for
r
in
radius
.
tolist
()]
return
{
'extrinsics'
:
extrinsics
,
'intrinsics'
:
intrinsics
,
'near'
:
near
,
}
def
_render_batch
(
self
,
reps
:
List
[
Mesh
],
extrinsics
:
torch
.
Tensor
,
intrinsics
:
torch
.
Tensor
,
near
:
List
,
return_types
=
[
'mask'
,
'normal'
,
'depth'
])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Render a batch of representations.
Args:
reps: The dictionary of lists of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
Returns:
a dict with
mask : [N x 1 x H x W] tensor of rendered masks
normal : [N x 3 x H x W] tensor of rendered normals
depth : [N x 1 x H x W] tensor of rendered depths
"""
ret
=
{
k
:
[]
for
k
in
return_types
}
for
i
,
rep
in
enumerate
(
reps
):
self
.
renderer
.
rendering_options
[
'near'
]
=
near
[
i
]
self
.
renderer
.
rendering_options
[
'far'
]
=
near
[
i
]
+
2
out_dict
=
self
.
renderer
.
render
(
rep
,
extrinsics
[
i
],
intrinsics
[
i
],
return_types
=
return_types
)
for
k
in
out_dict
:
ret
[
k
].
append
(
out_dict
[
k
][
None
]
if
k
in
[
'mask'
,
'depth'
]
else
out_dict
[
k
])
for
k
in
ret
:
ret
[
k
]
=
torch
.
stack
(
ret
[
k
])
return
ret
def
training_losses
(
self
,
vertices
:
sp
.
SparseTensor
,
intersected
:
sp
.
SparseTensor
,
mesh
:
List
[
Mesh
],
)
->
Tuple
[
Dict
,
Dict
]:
"""
Compute training losses.
Args:
vertices (SparseTensor): vertices of each active voxel
intersected (SparseTensor): intersected flag of each active voxel
mesh (List[Mesh]): the list of meshes to render
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z
,
mean
,
logvar
=
self
.
training_models
[
'encoder'
](
vertices
,
intersected
,
sample_posterior
=
True
,
return_raw
=
True
)
recon
,
pred_vertice
,
pred_intersected
,
subs_gt
,
subs
=
self
.
training_models
[
'decoder'
](
z
,
intersected
)
terms
=
edict
(
loss
=
0.0
)
# direct regression
if
self
.
lambda_intersected
>
0
:
terms
[
"direct/intersected"
]
=
F
.
binary_cross_entropy_with_logits
(
pred_intersected
.
feats
.
flatten
(),
intersected
.
feats
.
flatten
().
float
())
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lambda_intersected
*
terms
[
"direct/intersected"
]
if
self
.
lambda_vertice
>
0
:
terms
[
"direct/vertice"
]
=
F
.
mse_loss
(
pred_vertice
.
feats
,
vertices
.
feats
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lambda_vertice
*
terms
[
"direct/vertice"
]
# subdivision prediction loss
for
i
,
(
sub_gt
,
sub
)
in
enumerate
(
zip
(
subs_gt
,
subs
)):
terms
[
f
"bce_sub
{
i
}
"
]
=
F
.
binary_cross_entropy_with_logits
(
sub
.
feats
,
sub_gt
.
float
())
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lambda_subdiv
*
terms
[
f
"bce_sub
{
i
}
"
]
# rendering loss
cameras
=
self
.
_randomize_camera
(
len
(
mesh
))
gt_renders
=
self
.
_render_batch
(
mesh
,
**
cameras
,
return_types
=
[
'mask'
,
'normal'
,
'depth'
])
pred_renders
=
self
.
_render_batch
(
recon
,
**
cameras
,
return_types
=
[
'mask'
,
'normal'
,
'depth'
])
terms
[
'render/mask'
]
=
l1_loss
(
pred_renders
[
'mask'
],
gt_renders
[
'mask'
])
terms
[
'render/depth'
]
=
l1_loss
(
pred_renders
[
'depth'
],
gt_renders
[
'depth'
])
terms
[
'render/normal/l1'
]
=
l1_loss
(
pred_renders
[
'normal'
],
gt_renders
[
'normal'
])
terms
[
'render/normal/ssim'
]
=
1
-
ssim
(
pred_renders
[
'normal'
],
gt_renders
[
'normal'
])
terms
[
'render/normal/lpips'
]
=
lpips
(
pred_renders
[
'normal'
],
gt_renders
[
'normal'
])
terms
[
'loss'
]
=
terms
[
'loss'
]
+
\
self
.
lambda_mask
*
terms
[
'render/mask'
]
+
\
self
.
lambda_depth
*
terms
[
'render/depth'
]
+
\
self
.
lambda_normal
*
(
terms
[
'render/normal/l1'
]
+
self
.
lambda_ssim
*
terms
[
'render/normal/ssim'
]
+
self
.
lambda_lpips
*
terms
[
'render/normal/lpips'
])
# KL regularization
terms
[
"kl"
]
=
0.5
*
torch
.
mean
(
mean
.
pow
(
2
)
+
logvar
.
exp
()
-
logvar
-
1
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lambda_kl
*
terms
[
"kl"
]
return
terms
,
{}
@
torch
.
no_grad
()
def
run_snapshot
(
self
,
num_samples
:
int
,
batch_size
:
int
,
verbose
:
bool
=
False
,
)
->
Dict
:
dataloader
=
DataLoader
(
copy
.
deepcopy
(
self
.
dataset
),
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
1
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
# inference
gts
=
[]
recons
=
[]
recons2
=
[]
self
.
models
[
'encoder'
].
eval
()
for
i
in
range
(
0
,
num_samples
,
batch_size
):
batch
=
min
(
batch_size
,
num_samples
-
i
)
data
=
next
(
iter
(
dataloader
))
args
=
{
k
:
v
[:
batch
]
for
k
,
v
in
data
.
items
()}
args
=
recursive_to_device
(
args
,
self
.
device
)
z
=
self
.
models
[
'encoder'
](
args
[
'vertices'
],
args
[
'intersected'
])
self
.
models
[
'decoder'
].
train
()
y
=
self
.
models
[
'decoder'
](
z
,
args
[
'intersected'
])[
0
]
z
.
clear_spatial_cache
()
self
.
models
[
'decoder'
].
eval
()
y2
=
self
.
models
[
'decoder'
](
z
)
gts
.
extend
(
args
[
'mesh'
])
recons
.
extend
(
y
)
recons2
.
extend
(
y2
)
self
.
models
[
'encoder'
].
train
()
self
.
models
[
'decoder'
].
train
()
cameras
=
self
.
_randomize_camera
(
num_samples
)
gt_renders
=
self
.
_render_batch
(
gts
,
**
cameras
,
return_types
=
[
'normal'
])
recons_renders
=
self
.
_render_batch
(
recons
,
**
cameras
,
return_types
=
[
'normal'
])
recons2_renders
=
self
.
_render_batch
(
recons2
,
**
cameras
,
return_types
=
[
'normal'
])
sample_dict
=
{
'gt'
:
{
'value'
:
gt_renders
[
'normal'
],
'type'
:
'image'
},
'rec'
:
{
'value'
:
recons_renders
[
'normal'
],
'type'
:
'image'
},
'rec2'
:
{
'value'
:
recons2_renders
[
'normal'
],
'type'
:
'image'
},
}
return
sample_dict
TRELLIS.2_DCU/trellis2/trainers/vae/sparse_structure_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
copy
import
torch
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
from
easydict
import
EasyDict
as
edict
from
..basic
import
BasicTrainer
class
SparseStructureVaeTrainer
(
BasicTrainer
):
"""
Trainer for Sparse Structure VAE.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
lambda_kl (float): KL divergence loss weight.
"""
def
__init__
(
self
,
*
args
,
loss_type
=
'bce'
,
lambda_kl
=
1e-6
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
loss_type
=
loss_type
self
.
lambda_kl
=
lambda_kl
def
training_losses
(
self
,
ss
:
torch
.
Tensor
,
**
kwargs
)
->
Tuple
[
Dict
,
Dict
]:
"""
Compute training losses.
Args:
ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z
,
mean
,
logvar
=
self
.
training_models
[
'encoder'
](
ss
.
float
(),
sample_posterior
=
True
,
return_raw
=
True
)
logits
=
self
.
training_models
[
'decoder'
](
z
)
terms
=
edict
(
loss
=
0.0
)
if
self
.
loss_type
==
'bce'
:
terms
[
"bce"
]
=
F
.
binary_cross_entropy_with_logits
(
logits
,
ss
.
float
(),
reduction
=
'mean'
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
terms
[
"bce"
]
elif
self
.
loss_type
==
'l1'
:
terms
[
"l1"
]
=
F
.
l1_loss
(
F
.
sigmoid
(
logits
),
ss
.
float
(),
reduction
=
'mean'
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
terms
[
"l1"
]
elif
self
.
loss_type
==
'dice'
:
logits
=
F
.
sigmoid
(
logits
)
terms
[
"dice"
]
=
1
-
(
2
*
(
logits
*
ss
.
float
()).
sum
()
+
1
)
/
(
logits
.
sum
()
+
ss
.
float
().
sum
()
+
1
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
terms
[
"dice"
]
else
:
raise
ValueError
(
f
'Invalid loss type
{
self
.
loss_type
}
'
)
terms
[
"kl"
]
=
0.5
*
torch
.
mean
(
mean
.
pow
(
2
)
+
logvar
.
exp
()
-
logvar
-
1
)
terms
[
"loss"
]
=
terms
[
"loss"
]
+
self
.
lamda_kl
*
terms
[
"kl"
]
return
terms
,
{}
@
torch
.
no_grad
()
def
snapshot
(
self
,
suffix
=
None
,
num_samples
=
64
,
batch_size
=
1
,
verbose
=
False
):
super
().
snapshot
(
suffix
=
suffix
,
num_samples
=
num_samples
,
batch_size
=
batch_size
,
verbose
=
verbose
)
@
torch
.
no_grad
()
def
run_snapshot
(
self
,
num_samples
:
int
,
batch_size
:
int
,
verbose
:
bool
=
False
,
)
->
Dict
:
dataloader
=
DataLoader
(
copy
.
deepcopy
(
self
.
dataset
),
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
0
,
collate_fn
=
self
.
dataset
.
collate_fn
if
hasattr
(
self
.
dataset
,
'collate_fn'
)
else
None
,
)
# inference
gts
=
[]
recons
=
[]
for
i
in
range
(
0
,
num_samples
,
batch_size
):
batch
=
min
(
batch_size
,
num_samples
-
i
)
data
=
next
(
iter
(
dataloader
))
args
=
{
k
:
v
[:
batch
].
cuda
()
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
[:
batch
]
for
k
,
v
in
data
.
items
()}
z
=
self
.
models
[
'encoder'
](
args
[
'ss'
].
float
(),
sample_posterior
=
False
)
logits
=
self
.
models
[
'decoder'
](
z
)
recon
=
(
logits
>
0
).
long
()
gts
.
append
(
args
[
'ss'
])
recons
.
append
(
recon
)
sample_dict
=
{
'gt'
:
{
'value'
:
torch
.
cat
(
gts
,
dim
=
0
),
'type'
:
'sample'
},
'recon'
:
{
'value'
:
torch
.
cat
(
recons
,
dim
=
0
),
'type'
:
'sample'
},
}
return
sample_dict
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment