Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1066 additions
and
0 deletions
+1066
-0
TRELLIS.2_DCU/trellis2/models/sparse_structure_vae.py
TRELLIS.2_DCU/trellis2/models/sparse_structure_vae.py
+306
-0
TRELLIS.2_DCU/trellis2/models/structured_latent_flow.py
TRELLIS.2_DCU/trellis2/models/structured_latent_flow.py
+207
-0
TRELLIS.2_DCU/trellis2/modules/__pycache__/image_feature_extractor.cpython-310.pyc
...dules/__pycache__/image_feature_extractor.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/__pycache__/norm.cpython-310.pyc
...S.2_DCU/trellis2/modules/__pycache__/norm.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/__pycache__/spatial.cpython-310.pyc
..._DCU/trellis2/modules/__pycache__/spatial.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/__pycache__/utils.cpython-310.pyc
....2_DCU/trellis2/modules/__pycache__/utils.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/__init__.py
TRELLIS.2_DCU/trellis2/modules/attention/__init__.py
+3
-0
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/__init__.cpython-310.pyc
...s2/modules/attention/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/config.cpython-310.pyc
...lis2/modules/attention/__pycache__/config.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/full_attn.cpython-310.pyc
...2/modules/attention/__pycache__/full_attn.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/modules.cpython-310.pyc
...is2/modules/attention/__pycache__/modules.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/rope.cpython-310.pyc
...ellis2/modules/attention/__pycache__/rope.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/attention/config.py
TRELLIS.2_DCU/trellis2/modules/attention/config.py
+32
-0
TRELLIS.2_DCU/trellis2/modules/attention/full_attn.py
TRELLIS.2_DCU/trellis2/modules/attention/full_attn.py
+145
-0
TRELLIS.2_DCU/trellis2/modules/attention/modules.py
TRELLIS.2_DCU/trellis2/modules/attention/modules.py
+102
-0
TRELLIS.2_DCU/trellis2/modules/attention/rope.py
TRELLIS.2_DCU/trellis2/modules/attention/rope.py
+48
-0
TRELLIS.2_DCU/trellis2/modules/image_feature_extractor.py
TRELLIS.2_DCU/trellis2/modules/image_feature_extractor.py
+122
-0
TRELLIS.2_DCU/trellis2/modules/norm.py
TRELLIS.2_DCU/trellis2/modules/norm.py
+32
-0
TRELLIS.2_DCU/trellis2/modules/sparse/__init__.py
TRELLIS.2_DCU/trellis2/modules/sparse/__init__.py
+69
-0
TRELLIS.2_DCU/trellis2/modules/sparse/__pycache__/__init__.cpython-310.pyc
...llis2/modules/sparse/__pycache__/__init__.cpython-310.pyc
+0
-0
No files found.
TRELLIS.2_DCU/trellis2/models/sparse_structure_vae.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
..modules.norm
import
GroupNorm32
,
ChannelLayerNorm32
from
..modules.spatial
import
pixel_shuffle_3d
from
..modules.utils
import
zero_module
,
convert_module_to_f16
,
convert_module_to_f32
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build the normalization layer selected by ``norm_type``.

    Args:
        norm_type (str): ``"group"`` builds a ``GroupNorm32`` with 32 as its
            first argument; ``"layer"`` builds a ``ChannelLayerNorm32``.
        *args, **kwargs: Forwarded unchanged to the chosen layer.

    Raises:
        ValueError: If ``norm_type`` is neither ``"group"`` nor ``"layer"``.
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
class ResBlock3d(nn.Module):
    """
    Residual block made of two norm -> SiLU -> 3x3x3 convolution stages.

    Args:
        channels (int): Channels of the input.
        out_channels (Optional[int]): Channels of the output. Falls back to
            ``channels`` when omitted (or falsy).
        norm_type (Literal["group", "layer"]): Normalization passed to ``norm_layer``.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        # The second convolution is wrapped by zero_module
        # (presumably zero-initialised -- confirm in modules.utils).
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        # A 1x1x1 projection is only needed when the channel count changes.
        if channels != self.out_channels:
            self.skip_connection = nn.Conv3d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both stages to ``x`` and add the (possibly projected) input."""
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip_connection(x)
class DownsampleBlock3d(nn.Module):
    """
    Halve the three spatial dimensions of a 5-D tensor.

    Args:
        in_channels (int): Channels of the input.
        out_channels (int): Channels of the output.
        mode (Literal["conv", "avgpool"]): ``"conv"`` uses a kernel-2, stride-2
            ``nn.Conv3d``; ``"avgpool"`` uses ``F.avg_pool3d`` and therefore
            requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` with the convolution if one was built, else by average pooling."""
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)
class UpsampleBlock3d(nn.Module):
    """
    Double the three spatial dimensions of a 5-D tensor.

    Args:
        in_channels (int): Channels of the input.
        out_channels (int): Channels of the output.
        mode (Literal["conv", "nearest"]): ``"conv"`` applies a 3x3x3 convolution
            producing ``out_channels * 8`` channels followed by
            ``pixel_shuffle_3d(x, 2)``; ``"nearest"`` uses nearest-neighbour
            interpolation and therefore requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1)
        elif mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` with conv + pixel shuffle if a convolution was built, else by interpolation."""
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(conv(x), 2)
class SparseStructureEncoder(nn.Module):
    r"""
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Args:
        in_channels (int): Channels of the input.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
    """

    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)

        # NOTE(review): the residual blocks below are built with ResBlock3d's
        # default norm; ``norm_type`` only reaches ``out_layer``. Left as-is so
        # existing checkpoints keep their behaviour -- confirm this is intended.
        self.blocks = nn.ModuleList([])
        last_stage = len(channels) - 1
        for i, ch in enumerate(channels):
            self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
            # Every stage except the last ends with a downsampling step
            # that also switches to the next stage's channel count.
            if i < last_stage:
                self.blocks.append(DownsampleBlock3d(ch, channels[i + 1]))

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        # Twice the latent channels: forward() splits them into mean and logvar.
        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], latent_channels * 2, 3, padding=1),
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model (``blocks`` and ``middle_block``) to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model (``blocks`` and ``middle_block``) to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(
        self,
        x: torch.Tensor,
        sample_posterior: bool = False,
        return_raw: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Encode ``x`` into a latent.

        Args:
            x (torch.Tensor): Input passed to the 3-D input convolution.
            sample_posterior (bool): If True, return ``mean + std * noise`` with
                ``std = exp(0.5 * logvar)``; otherwise return the mean.
            return_raw (bool): If True, return ``(z, mean, logvar)`` instead of ``z``.

        Returns:
            ``z``, or the tuple ``(z, mean, logvar)`` when ``return_raw`` is True.
        """
        h = self.input_layer(x)
        # The torso runs in self.dtype; input/output layers keep x's dtype.
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        h = self.middle_block(h)
        h = h.type(x.dtype)
        h = self.out_layer(h)

        mean, logvar = h.chunk(2, dim=1)

        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean

        if return_raw:
            return z, mean, logvar
        return z
class SparseStructureDecoder(nn.Module):
    r"""
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Args:
        out_channels (int): Channels of the output.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
    """

    def __init__(
        self,
        out_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        first, last = channels[0], channels[-1]

        self.input_layer = nn.Conv3d(latent_channels, first, 3, padding=1)

        # The middle blocks run first, at the first stage's channel count.
        self.middle_block = nn.Sequential(*[
            ResBlock3d(first, first)
            for _ in range(num_res_blocks_middle)
        ])

        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
            # Every stage except the last ends with an upsampling step.
            if i < len(channels) - 1:
                self.blocks.append(UpsampleBlock3d(ch, channels[i + 1]))

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, last),
            nn.SiLU(),
            nn.Conv3d(last, out_channels, 3, padding=1),
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model (``blocks`` and ``middle_block``) to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        for part in (self.blocks, self.middle_block):
            part.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model (``blocks`` and ``middle_block``) to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        for part in (self.blocks, self.middle_block):
            part.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode the latent ``x``; the result is produced in ``x``'s dtype.
        """
        # The torso runs in self.dtype; input/output layers keep x's dtype.
        h = self.input_layer(x).type(self.dtype)
        h = self.middle_block(h)
        for block in self.blocks:
            h = block(h)
        return self.out_layer(h.type(x.dtype))
TRELLIS.2_DCU/trellis2/models/structured_latent_flow.py
0 → 100644
View file @
f05e915f
from
typing
import
*
from
functools
import
partial
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
..modules.utils
import
convert_module_to
,
manual_cast
,
str_to_dtype
from
..modules.transformer
import
AbsolutePositionEmbedder
from
..modules
import
sparse
as
sp
from
..modules.sparse.transformer
import
ModulatedSparseTransformerCrossBlock
from
.sparse_structure_flow
import
TimestepEmbedder
from
.sparse_elastic_mixin
import
SparseTransformerElasticMixin
class SLatFlowModel(nn.Module):
    """
    Flow model over sparse tensors built from modulated cross-attention
    transformer blocks.

    Args:
        resolution (int): Stored on the instance; not otherwise used in this class.
        in_channels (int): Input feature channels of ``input_layer``.
        model_channels (int): Width of the transformer blocks.
        cond_channels (int): Channels of the conditioning passed to each block.
        out_channels (int): Output feature channels of ``out_layer``.
        num_blocks (int): Number of transformer blocks.
        num_heads (Optional[int]): Attention heads; when omitted (or falsy),
            ``model_channels // num_head_channels`` is used.
        num_head_channels (Optional[int]): Channels per head used for that fallback.
        mlp_ratio (float): Forwarded to every block.
        pe_mode (Literal["ape", "rope"]): ``"ape"`` adds an absolute position
            embedding in ``forward``; ``"rope"`` enables ``use_rope`` in the blocks.
        rope_freq (Tuple[float, float]): Forwarded to every block.
        dtype (str): Name converted with ``str_to_dtype``; dtype of the blocks.
        use_checkpoint (bool): Forwarded to every block.
        share_mod (bool): If True, a single ``adaLN_modulation`` on the model is
            applied to the timestep embedding before the blocks.
        initialization (str): ``'vanilla'`` or ``'scaled'``; any other value
            leaves the default initialization in place.
        qk_rms_norm (bool): Forwarded to every block.
        qk_rms_norm_cross (bool): Forwarded to every block.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        pe_mode: Literal["ape", "rope"] = "ape",
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        dtype: str = 'float32',
        use_checkpoint: bool = False,
        share_mod: bool = False,
        initialization: str = 'vanilla',
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.pe_mode = pe_mode
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.initialization = initialization
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = str_to_dtype(dtype)

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True),
            )

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)

        # Every block is configured identically.
        block_options = dict(
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            attn_mode='full',
            use_checkpoint=self.use_checkpoint,
            use_rope=(pe_mode == "rope"),
            rope_freq=rope_freq,
            share_mod=self.share_mod,
            qk_rms_norm=self.qk_rms_norm,
            qk_rms_norm_cross=self.qk_rms_norm_cross,
        )
        self.blocks = nn.ModuleList([
            ModulatedSparseTransformerCrossBlock(model_channels, cond_channels, **block_options)
            for _ in range(num_blocks)
        ])

        self.out_layer = sp.SparseLinear(model_channels, out_channels)

        self.initialize_weights()
        self.convert_to(self.dtype)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to(self, dtype: torch.dtype) -> None:
        """
        Convert the torso of the model (``blocks``) to the specified dtype.
        """
        self.dtype = dtype
        cast = partial(convert_module_to, dtype=dtype)
        self.blocks.apply(cast)

    def initialize_weights(self) -> None:
        """
        Initialize parameters according to ``self.initialization``.

        Both supported schemes end with the same steps (timestep-embedder
        weights, zeroed adaLN modulation, zeroed output layer); those are
        written once below. The order of the random initializations is the
        same as when each scheme spelled them out separately.
        """
        scheme = self.initialization

        if scheme == 'vanilla':
            # Initialize transformer layers:
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

        elif scheme == 'scaled':
            # Initialize transformer layers:
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

            # Scaled init for to_out and ffn2
            def _scaled_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            for block in self.blocks:
                block.self_attn.to_out.apply(_scaled_init)
                block.cross_attn.to_out.apply(_scaled_init)
                block.mlp.mlp[2].apply(_scaled_init)

            # Initialize input layer to make the initial representation have variance 1
            nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
            nn.init.zeros_(self.input_layer.bias)

        else:
            # Unrecognised scheme: nothing is re-initialized.
            return

        # Initialize timestep embedding MLP:
        for index in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[index].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            modulations = [self.adaLN_modulation]
        else:
            modulations = [block.adaLN_modulation for block in self.blocks]
        for modulation in modulations:
            nn.init.constant_(modulation[-1].weight, 0)
            nn.init.constant_(modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(
        self,
        x: sp.SparseTensor,
        t: torch.Tensor,
        cond: Union[torch.Tensor, List[torch.Tensor]],
        concat_cond: Optional[sp.SparseTensor] = None,
        **kwargs
    ) -> sp.SparseTensor:
        """
        Run the model.

        Args:
            x (sp.SparseTensor): Input sparse tensor.
            t (torch.Tensor): Timesteps passed to the timestep embedder.
            cond: Conditioning; a list of tensors is packed with
                ``sp.VarLenTensor.from_tensor_list``.
            concat_cond (Optional[sp.SparseTensor]): If given, concatenated to
                ``x`` along the last dimension before the input layer.
            **kwargs: Accepted and ignored.

        Returns:
            sp.SparseTensor: Output of ``out_layer``.
        """
        if concat_cond is not None:
            x = sp.sparse_cat([x, concat_cond], dim=-1)
        if isinstance(cond, list):
            cond = sp.VarLenTensor.from_tensor_list(cond)

        h = manual_cast(self.input_layer(x), self.dtype)

        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = manual_cast(t_emb, self.dtype)
        cond = manual_cast(cond, self.dtype)

        if self.pe_mode == "ape":
            # Column 0 of coords is skipped (presumably the batch index -- confirm).
            pe = self.pos_embedder(h.coords[:, 1:])
            h = h + manual_cast(pe, self.dtype)

        for block in self.blocks:
            h = block(h, t_emb, cond)

        # Back to the input dtype, then a parameter-free layer norm over the features.
        h = manual_cast(h, x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        return self.out_layer(h)
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
    """
    ``SLatFlowModel`` combined with ``SparseTransformerElasticMixin``.

    Adds elastic memory management, intended for training with low VRAM.
    The class body adds nothing of its own.
    """
TRELLIS.2_DCU/trellis2/modules/__pycache__/image_feature_extractor.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/__pycache__/norm.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/__pycache__/spatial.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/__pycache__/utils.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/__init__.py
0 → 100644
View file @
f05e915f
from
.full_attn
import
*
from
.modules
import
*
from
.rope
import
*
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/config.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/full_attn.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/modules.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/__pycache__/rope.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/attention/config.py
0 → 100644
View file @
f05e915f
from
typing
import
*
# Attention backend used by default; may be overridden from the environment
# (see __from_env) or at runtime (see set_backend).
BACKEND = 'flash_attn'
DEBUG = False

# Backend names accepted from the ATTN_BACKEND environment variable.
_VALID_BACKENDS = ('xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive')


def __from_env():
    """
    Read ``ATTN_BACKEND`` and ``ATTN_DEBUG`` from the environment.

    ``ATTN_BACKEND`` replaces ``BACKEND`` when it names a known backend; an
    unknown value is reported and ignored instead of being dropped silently.
    ``ATTN_DEBUG`` sets ``DEBUG`` to True only when it equals ``'1'``.
    The selected backend is printed in every case.
    """
    import os

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get('ATTN_BACKEND')
    env_attn_debug = os.environ.get('ATTN_DEBUG')

    if env_attn_backend is not None:
        if env_attn_backend in _VALID_BACKENDS:
            BACKEND = env_attn_backend
        else:
            # Most likely a typo; say so rather than falling back unnoticed.
            print(
                f"[ATTENTION] Ignoring unknown ATTN_BACKEND {env_attn_backend!r}; "
                f"expected one of {list(_VALID_BACKENDS)}"
            )
    if env_attn_debug is not None:
        DEBUG = env_attn_debug == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()


def set_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']):
    """
    Select the attention backend at runtime.

    The value is stored as given; it is not validated here.
    """
    global BACKEND
    BACKEND = backend


def set_debug(debug: bool):
    """Enable or disable the attention debug flag."""
    global DEBUG
    DEBUG = debug
TRELLIS.2_DCU/trellis2/modules/attention/full_attn.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
math
from
.
import
config
__all__
=
[
'scaled_dot_product_attention'
,
]
def _naive_sdpa(q, k, v):
    """
    Reference scaled dot product attention written with plain tensor ops.

    The inputs are indexed as [N, L, H, C]; heads are moved ahead of the
    sequence axis for the matrix products and moved back for the result.
    """
    # [N, L, H, C] -> [N, H, L, C]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    scale_factor = 1 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
    weights = torch.softmax(scores, dim=-1)
    # [N, H, L, C] -> [N, L, H, C]
    return torch.matmul(weights, v).transpose(1, 2)
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...


@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...


@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...


def scaled_dot_product_attention(*args, **kwargs):
    """
    Apply scaled dot product attention with the backend named by ``config.BACKEND``.

    Accepts one of three argument layouts, positionally and/or by keyword:
    ``(qkv)``, ``(q, kv)`` or ``(q, k, v)`` -- see the overloads for shapes.

    Backend modules are imported inside the branch that uses them, so only
    the selected backend has to be installed.

    Raises:
        AssertionError: On a wrong number of arguments, a missing keyword, or
            an argument whose rank/batch size does not match its layout.
        ValueError: If ``config.BACKEND`` is not a known backend.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v'],
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    arg_names = arg_names_dict[num_all_args]
    for key in arg_names[len(args):]:
        assert key in kwargs, f"Missing argument {key}"
    # Positional values first, the remaining ones looked up by keyword.
    values = list(args) + [kwargs[key] for key in arg_names[len(args):]]

    if num_all_args == 1:
        qkv = values[0]
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"

    elif num_all_args == 2:
        q, kv = values
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"

    elif num_all_args == 3:
        q, k, v = values
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"

    # NOTE: the imports below are deliberately unconditional. An import inside
    # a function binds a *local* name, so guarding it with
    # `if name not in globals()` never skips it in practice and would raise
    # UnboundLocalError if the name ever did exist at module level.
    if config.BACKEND == 'xformers':
        import xformers.ops as xops
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif config.BACKEND == 'flash_attn':
        import flash_attn
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif config.BACKEND == 'flash_attn_3':
        import flash_attn_interface as flash_attn_3
        if num_all_args == 1:
            out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
            out = flash_attn_3.flash_attn_func(q, k, v)
        elif num_all_args == 3:
            out = flash_attn_3.flash_attn_func(q, k, v)
    elif config.BACKEND == 'sdpa':
        from torch.nn.functional import scaled_dot_product_attention as sdpa
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
        out = sdpa(q, k, v)         # [N, H, L, C]
        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    elif config.BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention module: {config.BACKEND}")

    return out
TRELLIS.2_DCU/trellis2/modules/attention/modules.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.full_attn
import
scaled_dot_product_attention
from
.rope
import
RotaryPositionEmbedder
class MultiHeadRMSNorm(nn.Module):
    """
    Per-head normalization with a learnable gain.

    The last dimension of the input is L2-normalized in float32, multiplied by
    a ``(heads, dim)`` gain and by ``sqrt(dim)``, then cast back to the input dtype.

    Args:
        dim (int): Size of the last (per-head) dimension.
        heads (int): Number of heads.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension; the result keeps ``x``'s dtype."""
        unit = F.normalize(x.float(), dim=-1)
        scaled = unit * self.gamma * self.scale
        return scaled.to(x.dtype)
class MultiHeadAttention(nn.Module):
    """
    Multi-head self- or cross-attention over ``[B, L, C]`` inputs.

    Args:
        channels (int): Channels of the input and output; must be divisible
            by ``num_heads``.
        num_heads (int): Number of attention heads.
        ctx_channels (Optional[int]): Channels of the context for
            cross-attention; defaults to ``channels``.
        type (Literal["self", "cross"]): Kind of attention.
        attn_mode (Literal["full", "windowed"]): Only ``"full"`` is implemented.
        window_size (Optional[int]): Stored; unused while windowed attention
            is not implemented.
        shift_window (Optional[Tuple[int, int, int]]): Stored; unused likewise.
        qkv_bias (bool): Whether the q/k/v projections carry a bias.
        use_rope (bool): Apply rotary embedding to q and k (self-attention path).
        rope_freq (Tuple[float, float]): Accepted; not used by this class.
        qk_rms_norm (bool): Apply ``MultiHeadRMSNorm`` to q and k.

    Raises:
        NotImplementedError: If ``attn_mode`` is ``"windowed"``.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            # One fused projection produces q, k and v.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            # Queries come from x, keys/values from the context.
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        phases: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): ``[B, L, C]`` input.
            context (Optional[torch.Tensor]): Keys/values source for
                cross-attention; its second dimension is the key/value length.
            phases (Optional[torch.Tensor]): Rotary phases; required when
                ``use_rope`` is set (self-attention path).

        Returns:
            torch.Tensor: ``[B, L, C]`` output of the final projection.
        """
        B, L, C = x.shape
        heads = self.num_heads

        if self._type == "self":
            qkv = self.to_qkv(x).reshape(B, L, 3, heads, -1)
            if self.attn_mode == "full":
                if self.qk_rms_norm or self.use_rope:
                    # q and k need per-tensor processing, so unpack first.
                    q, k, v = qkv.unbind(dim=2)
                    if self.qk_rms_norm:
                        q, k = self.q_rms_norm(q), self.k_rms_norm(k)
                    if self.use_rope:
                        assert phases is not None, "Phases must be provided for RoPE"
                        q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
                        k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            Lkv = context.shape[1]
            q = self.to_q(x).reshape(B, L, heads, -1)
            kv = self.to_kv(context).reshape(B, Lkv, 2, heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)

        # Merge the heads back into the channel dimension.
        merged = h.reshape(B, L, -1)
        return self.to_out(merged)
TRELLIS.2_DCU/trellis2/modules/attention/rope.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
class RotaryPositionEmbedder(nn.Module):
    """
    Rotary position embedding for multi-dimensional positions.

    Args:
        head_dim (int): Per-head channel count; must be even.
        dim (int): Number of position coordinates per element.
        rope_freq (Tuple[float, float]): ``(scale, base)``; the frequencies are
            ``scale / base ** (i / freq_dim)`` for ``i`` in ``range(freq_dim)``,
            with ``freq_dim = head_dim // 2 // dim``.
    """

    def __init__(
        self,
        head_dim: int,
        dim: int = 3,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
    ):
        super().__init__()
        assert head_dim % 2 == 0, "Head dim must be divisible by 2"
        self.head_dim = head_dim
        self.dim = dim
        self.rope_freq = rope_freq
        self.freq_dim = head_dim // 2 // dim
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        # Kept as a plain tensor attribute (not a registered buffer).
        self.freqs = rope_freq[0] / (rope_freq[1] ** exponents)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit complex numbers whose angles are ``outer(indices, freqs)``."""
        # The frequencies follow the indices' device; the moved copy is cached.
        freqs = self.freqs.to(indices.device)
        self.freqs = freqs
        angles = torch.outer(indices, freqs)
        return torch.polar(torch.ones_like(angles), angles)

    @staticmethod
    def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """
        Rotate consecutive channel pairs of ``x`` by ``phases``.

        The last dimension of ``x`` is viewed as complex pairs in float32,
        multiplied by ``phases`` (broadcast over the second-to-last axis),
        flattened back, and cast to ``x``'s dtype.
        """
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * phases.unsqueeze(-2)
        flat = torch.view_as_real(rotated).reshape(*rotated.shape[:-1], -1)
        return flat.to(x.dtype)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices (torch.Tensor): [..., N, C] tensor of spatial positions

        Returns:
            torch.Tensor: Complex phases with ``head_dim // 2`` entries in the
            last dimension; slots not covered by ``dim * freq_dim`` are filled
            with 1 (zero rotation).
        """
        assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
        lead = indices.shape[:-1]
        phases = self._get_phases(indices.reshape(-1)).reshape(*lead, -1)
        missing = self.head_dim // 2 - phases.shape[-1]
        if missing > 0:
            identity = torch.polar(
                torch.ones(*lead, missing, device=phases.device),
                torch.zeros(*lead, missing, device=phases.device),
            )
            phases = torch.cat([phases, identity], dim=-1)
        return phases
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/image_feature_extractor.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn.functional
as
F
from
torchvision
import
transforms
from
transformers
import
DINOv3ViTModel
import
numpy
as
np
from
PIL
import
Image
class DinoV2FeatureExtractor:
    """
    Feature extractor for DINOv2 models.

    Args:
        model_name (str): Entry point name passed to
            ``torch.hub.load('facebookresearch/dinov2', ...)``.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the underlying model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the underlying model to the current CUDA device."""
        self.model.cuda()

    def cpu(self):
        """Move the underlying model to the CPU."""
        self.model.cpu()

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Extract features from the image.

        Inputs are moved to the device the model currently lives on, so the
        extractor works after ``to(device)`` / ``cpu()`` as well as ``cuda()``.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
                PIL images are resized to 518x518, converted to RGB and scaled to [0, 1].

        Returns:
            A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.

        Raises:
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        # Follow the model instead of assuming the default CUDA device.
        device = next(self.model.parameters()).device

        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            # HWC -> CHW
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.transform(image).to(device)
        # 'x_prenorm' is normalized here with a parameter-free layer norm.
        features = self.model(image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens
class DinoV3FeatureExtractor:
    """
    Feature extractor for DINOv3 models.

    Wraps a ``DINOv3ViTModel`` loaded with ``from_pretrained`` and returns
    layer-normalised hidden states for a batch of images. Inputs are moved
    to whatever device the backbone currently lives on.
    """

    def __init__(self, model_name: str, image_size=512):
        """
        Args:
            model_name: identifier passed to ``DINOv3ViTModel.from_pretrained``.
            image_size: side length PIL inputs are resized to (square).
        """
        self.model_name = model_name
        self.model = DINOv3ViTModel.from_pretrained(model_name)
        self.model.eval()
        self.image_size = image_size
        # Fixed per-channel mean/std normalisation applied to [0, 1] RGB tensors.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def to(self, device):
        """Move the backbone to ``device`` and return this extractor."""
        self.model.to(device)
        return self

    def cuda(self):
        """Move the backbone to the current CUDA device and return this extractor."""
        self.model.cuda()
        return self

    def cpu(self):
        """Move the backbone to the CPU and return this extractor."""
        self.model.cpu()
        return self

    def _model_device(self) -> torch.device:
        """Return the device holding the backbone's parameters."""
        return next(self.model.parameters()).device

    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on a normalised image batch.

        Args:
            image: batched image tensor already on the backbone's device.

        Returns:
            Hidden states after a parameter-free layer norm over the last dimension.
        """
        # transformers 5.x: DINOv3ViTModel is a backbone, use its forward() directly
        # NOTE(review): last_hidden_state presumably already went through the
        # backbone's own final norm, while the manual path below stops after
        # the last layer — confirm both branches yield equivalent features.
        if hasattr(self.model, 'model'):
            hidden_states = self.model(image).last_hidden_state
            return F.layer_norm(hidden_states, hidden_states.shape[-1:])

        # older transformers: manual layer iteration
        # Match the patch-embedding weight dtype so reduced-precision weights accept the input.
        image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.model.embeddings(image, bool_masked_pos=None)
        position_embeddings = self.model.rope_embeddings(image)
        for layer_module in self.model.layer:
            hidden_states = layer_module(
                hidden_states,
                position_embeddings=position_embeddings,
            )
        return F.layer_norm(hidden_states, hidden_states.shape[-1:])

    @torch.no_grad()
    def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Extract features from the image.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
                PIL images are resized to ``image_size`` x ``image_size``, converted to RGB
                and scaled to [0, 1].

        Returns:
            A tensor of shape (B, N, D) where N is the number of tokens returned by the
            backbone and D is the feature dimension.

        Raises:
            ValueError: if ``image`` is neither a tensor nor a list.
        """
        # Follow the backbone's placement instead of assuming CUDA, so that
        # to()/cpu() on this extractor cannot desynchronise model and input.
        device = self._model_device()
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            # HWC -> CHW
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.transform(image).to(device)
        features = self.extract_features(image)
        return features
TRELLIS.2_DCU/trellis2/modules/norm.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
.utils
import
manual_cast
class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that routes its input through ``manual_cast`` to
    float32 before normalising and casts the result back afterwards.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the caller's dtype so the output can be returned in it.
        input_dtype = x.dtype
        normed = super().forward(manual_cast(x, torch.float32))
        return manual_cast(normed, input_dtype)
class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.

    The result is cast back to the dtype of the incoming tensor.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the caller's dtype so the output can be returned in it.
        input_dtype = x.dtype
        normed = super().forward(manual_cast(x, torch.float32))
        return manual_cast(normed, input_dtype)
class ChannelLayerNorm32(LayerNorm32):
    """
    LayerNorm32 applied over axis 1 of the input.

    Axis 1 is moved to the end, the parent normalisation runs on that
    trailing axis, and the original axis order is restored.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ndim = x.dim()
        # (0, 1, 2, ..., k) -> (0, 2, ..., k, 1)
        to_last = (0, *range(2, ndim), 1)
        # (0, 2, ..., k, 1) -> (0, 1, 2, ..., k)
        to_second = (0, ndim - 1, *range(1, ndim - 1))
        normed = super().forward(x.permute(*to_last).contiguous())
        return normed.permute(*to_second).contiguous()
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/sparse/__init__.py
0 → 100644
View file @
f05e915f
from
.
import
config
import
importlib
# Exported name -> sibling submodule that defines it. The module-level
# __getattr__ resolves these on first access.
__attributes = {
    # basic
    'VarLenTensor': 'basic',
    'varlen_cat': 'basic',
    'varlen_unbind': 'basic',
    'SparseTensor': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    # norm
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    # nonlinearity
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    # linear
    'SparseLinear': 'linear',
    # attention
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_cross_attention': 'attention',
    'SparseRotaryPositionEmbedder': 'attention',
    'SparseMultiHeadAttention': 'attention',
    # conv
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    # spatial
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide': 'spatial',
    'SparseSpatial2Channel': 'spatial',
    'SparseChannel2Spatial': 'spatial',
    'sparse_nearest_interpolate': 'spatial',
    'sparse_trilinear_interpolate': 'spatial',
    # serialize
    'encode_seq': 'serialize',
    'decode_seq': 'serialize',
}

# Submodules that are themselves exposed as attributes of this package.
__submodules = ['transformer', 'conv']

# Public API: every lazily exported name followed by the exposed submodules.
__all__ = [*__attributes, *__submodules]
def __getattr__(name):
    """
    Resolve a package attribute on first access and cache it in the module globals.

    Names listed in ``__attributes`` are fetched from the submodule that
    defines them; names listed in ``__submodules`` resolve to the imported
    submodule itself.

    Raises:
        AttributeError: if ``name`` is in neither table.
    """
    namespace = globals()
    if name in namespace:
        return namespace[name]

    if name in __attributes:
        owner = importlib.import_module(f".{__attributes[name]}", __name__)
        namespace[name] = getattr(owner, name)
    elif name in __submodules:
        namespace[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return namespace[name]
# For Pylance
# These imports only make the lazily exported names visible to static
# analysis; the guard is false whenever this file is imported as a package,
# so none of them run in that case.
if __name__ == '__main__':
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    from .serialize import *
    # NOTE(review): unlike the relative imports above, these two are absolute
    # imports — confirm whether `from . import transformer, conv` was intended.
    import transformer
    import conv
TRELLIS.2_DCU/trellis2/modules/sparse/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
Prev
1
…
5
6
7
8
9
10
11
12
13
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment