Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
994 additions
and
0 deletions
+994
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/blocks.py
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/blocks.py
+145
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/modulated.py
...IS.2_DCU/trellis2/modules/sparse/transformer/modulated.py
+166
-0
TRELLIS.2_DCU/trellis2/modules/spatial.py
TRELLIS.2_DCU/trellis2/modules/spatial.py
+48
-0
TRELLIS.2_DCU/trellis2/modules/transformer/__init__.py
TRELLIS.2_DCU/trellis2/modules/transformer/__init__.py
+2
-0
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/__init__.cpython-310.pyc
.../modules/transformer/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/blocks.cpython-310.pyc
...s2/modules/transformer/__pycache__/blocks.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/modulated.cpython-310.pyc
...modules/transformer/__pycache__/modulated.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/transformer/blocks.py
TRELLIS.2_DCU/trellis2/modules/transformer/blocks.py
+186
-0
TRELLIS.2_DCU/trellis2/modules/transformer/modulated.py
TRELLIS.2_DCU/trellis2/modules/transformer/modulated.py
+165
-0
TRELLIS.2_DCU/trellis2/modules/utils.py
TRELLIS.2_DCU/trellis2/modules/utils.py
+96
-0
TRELLIS.2_DCU/trellis2/pipelines/__init__.py
TRELLIS.2_DCU/trellis2/pipelines/__init__.py
+52
-0
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/__init__.cpython-310.pyc
...U/trellis2/pipelines/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/base.cpython-310.pyc
...2_DCU/trellis2/pipelines/__pycache__/base.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/trellis2_image_to_3d.cpython-310.pyc
...ipelines/__pycache__/trellis2_image_to_3d.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/base.py
TRELLIS.2_DCU/trellis2/pipelines/base.py
+72
-0
TRELLIS.2_DCU/trellis2/pipelines/rembg/BiRefNet.py
TRELLIS.2_DCU/trellis2/pipelines/rembg/BiRefNet.py
+55
-0
TRELLIS.2_DCU/trellis2/pipelines/rembg/__init__.py
TRELLIS.2_DCU/trellis2/pipelines/rembg/__init__.py
+1
-0
TRELLIS.2_DCU/trellis2/pipelines/rembg/__pycache__/BiRefNet.cpython-310.pyc
...lis2/pipelines/rembg/__pycache__/BiRefNet.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/rembg/__pycache__/__init__.cpython-310.pyc
...lis2/pipelines/rembg/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/pipelines/samplers/__init__.py
TRELLIS.2_DCU/trellis2/pipelines/samplers/__init__.py
+6
-0
No files found.
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/blocks.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..basic
import
VarLenTensor
,
SparseTensor
from
..linear
import
SparseLinear
from
..nonlinearity
import
SparseGELU
from
..attention
import
SparseMultiHeadAttention
from
...norm
import
LayerNorm32
class
SparseFeedForwardNet
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
mlp_ratio
:
float
=
4.0
):
super
().
__init__
()
self
.
mlp
=
nn
.
Sequential
(
SparseLinear
(
channels
,
int
(
channels
*
mlp_ratio
)),
SparseGELU
(
approximate
=
"tanh"
),
SparseLinear
(
int
(
channels
*
mlp_ratio
),
channels
),
)
def
forward
(
self
,
x
:
VarLenTensor
)
->
VarLenTensor
:
return
self
.
mlp
(
x
)
class
SparseTransformerBlock
(
nn
.
Module
):
"""
Sparse Transformer block (MSA + FFN).
"""
def
__init__
(
self
,
channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"swin"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
int
,
int
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
ln_affine
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
attn
=
SparseMultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
mlp
=
SparseFeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
def
_forward
(
self
,
x
:
SparseTensor
)
->
SparseTensor
:
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
self
.
attn
(
h
)
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm2
(
x
.
feats
))
h
=
self
.
mlp
(
h
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
SparseTensor
)
->
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
)
class
SparseTransformerCrossBlock
(
nn
.
Module
):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN).
"""
def
__init__
(
self
,
channels
:
int
,
ctx_channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"swin"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
qk_rms_norm
:
bool
=
False
,
qk_rms_norm_cross
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
ln_affine
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm3
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
self_attn
=
SparseMultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
type
=
"self"
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
cross_attn
=
SparseMultiHeadAttention
(
channels
,
ctx_channels
=
ctx_channels
,
num_heads
=
num_heads
,
type
=
"cross"
,
attn_mode
=
"full"
,
qkv_bias
=
qkv_bias
,
qk_rms_norm
=
qk_rms_norm_cross
,
)
self
.
mlp
=
SparseFeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
def
_forward
(
self
,
x
:
SparseTensor
,
context
:
Union
[
torch
.
Tensor
,
VarLenTensor
])
->
SparseTensor
:
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
self
.
self_attn
(
h
)
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm2
(
x
.
feats
))
h
=
self
.
cross_attn
(
h
,
context
)
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm3
(
x
.
feats
))
h
=
self
.
mlp
(
h
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
SparseTensor
,
context
:
Union
[
torch
.
Tensor
,
VarLenTensor
])
->
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
context
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
context
)
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/modulated.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..basic
import
VarLenTensor
,
SparseTensor
from
..attention
import
SparseMultiHeadAttention
from
...norm
import
LayerNorm32
from
.blocks
import
SparseFeedForwardNet
class
ModulatedSparseTransformerBlock
(
nn
.
Module
):
"""
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def
__init__
(
self
,
channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"swin"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
float
,
float
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
share_mod
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
share_mod
=
share_mod
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
attn
=
SparseMultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
mlp
=
SparseFeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
if
not
share_mod
:
self
.
adaLN_modulation
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
channels
,
6
*
channels
,
bias
=
True
)
)
else
:
self
.
modulation
=
nn
.
Parameter
(
torch
.
randn
(
6
*
channels
)
/
channels
**
0.5
)
def
_forward
(
self
,
x
:
SparseTensor
,
mod
:
torch
.
Tensor
)
->
SparseTensor
:
if
self
.
share_mod
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
(
self
.
modulation
+
mod
).
type
(
mod
.
dtype
).
chunk
(
6
,
dim
=
1
)
else
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
self
.
adaLN_modulation
(
mod
).
chunk
(
6
,
dim
=
1
)
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
*
(
1
+
scale_msa
)
+
shift_msa
h
=
self
.
attn
(
h
)
h
=
h
*
gate_msa
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm2
(
x
.
feats
))
h
=
h
*
(
1
+
scale_mlp
)
+
shift_mlp
h
=
self
.
mlp
(
h
)
h
=
h
*
gate_mlp
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
SparseTensor
,
mod
:
torch
.
Tensor
)
->
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
mod
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
mod
)
class
ModulatedSparseTransformerCrossBlock
(
nn
.
Module
):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def
__init__
(
self
,
channels
:
int
,
ctx_channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"swin"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
float
,
float
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qk_rms_norm_cross
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
share_mod
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
share_mod
=
share_mod
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm3
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
self_attn
=
SparseMultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
type
=
"self"
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
cross_attn
=
SparseMultiHeadAttention
(
channels
,
ctx_channels
=
ctx_channels
,
num_heads
=
num_heads
,
type
=
"cross"
,
attn_mode
=
"full"
,
qkv_bias
=
qkv_bias
,
qk_rms_norm
=
qk_rms_norm_cross
,
)
self
.
mlp
=
SparseFeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
if
not
share_mod
:
self
.
adaLN_modulation
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
channels
,
6
*
channels
,
bias
=
True
)
)
else
:
self
.
modulation
=
nn
.
Parameter
(
torch
.
randn
(
6
*
channels
)
/
channels
**
0.5
)
def
_forward
(
self
,
x
:
SparseTensor
,
mod
:
torch
.
Tensor
,
context
:
Union
[
torch
.
Tensor
,
VarLenTensor
])
->
SparseTensor
:
if
self
.
share_mod
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
(
self
.
modulation
+
mod
).
type
(
mod
.
dtype
).
chunk
(
6
,
dim
=
1
)
else
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
self
.
adaLN_modulation
(
mod
).
chunk
(
6
,
dim
=
1
)
h
=
x
.
replace
(
self
.
norm1
(
x
.
feats
))
h
=
h
*
(
1
+
scale_msa
)
+
shift_msa
h
=
self
.
self_attn
(
h
)
h
=
h
*
gate_msa
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm2
(
x
.
feats
))
h
=
self
.
cross_attn
(
h
,
context
)
x
=
x
+
h
h
=
x
.
replace
(
self
.
norm3
(
x
.
feats
))
h
=
h
*
(
1
+
scale_mlp
)
+
shift_mlp
h
=
self
.
mlp
(
h
)
h
=
h
*
gate_mlp
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
SparseTensor
,
mod
:
torch
.
Tensor
,
context
:
Union
[
torch
.
Tensor
,
VarLenTensor
])
->
SparseTensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
mod
,
context
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
mod
,
context
)
TRELLIS.2_DCU/trellis2/modules/spatial.py
0 → 100644
View file @
f05e915f
import
torch
def
pixel_shuffle_3d
(
x
:
torch
.
Tensor
,
scale_factor
:
int
)
->
torch
.
Tensor
:
"""
3D pixel shuffle.
"""
B
,
C
,
H
,
W
,
D
=
x
.
shape
C_
=
C
//
scale_factor
**
3
x
=
x
.
reshape
(
B
,
C_
,
scale_factor
,
scale_factor
,
scale_factor
,
H
,
W
,
D
)
x
=
x
.
permute
(
0
,
1
,
5
,
2
,
6
,
3
,
7
,
4
)
x
=
x
.
reshape
(
B
,
C_
,
H
*
scale_factor
,
W
*
scale_factor
,
D
*
scale_factor
)
return
x
def
patchify
(
x
:
torch
.
Tensor
,
patch_size
:
int
):
"""
Patchify a tensor.
Args:
x (torch.Tensor): (N, C, *spatial) tensor
patch_size (int): Patch size
"""
DIM
=
x
.
dim
()
-
2
for
d
in
range
(
2
,
DIM
+
2
):
assert
x
.
shape
[
d
]
%
patch_size
==
0
,
f
"Dimension
{
d
}
of input tensor must be divisible by patch size, got
{
x
.
shape
[
d
]
}
and
{
patch_size
}
"
x
=
x
.
reshape
(
*
x
.
shape
[:
2
],
*
sum
([[
x
.
shape
[
d
]
//
patch_size
,
patch_size
]
for
d
in
range
(
2
,
DIM
+
2
)],
[]))
x
=
x
.
permute
(
0
,
1
,
*
([
2
*
i
+
3
for
i
in
range
(
DIM
)]
+
[
2
*
i
+
2
for
i
in
range
(
DIM
)]))
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
]
*
(
patch_size
**
DIM
),
*
(
x
.
shape
[
-
DIM
:]))
return
x
def
unpatchify
(
x
:
torch
.
Tensor
,
patch_size
:
int
):
"""
Unpatchify a tensor.
Args:
x (torch.Tensor): (N, C, *spatial) tensor
patch_size (int): Patch size
"""
DIM
=
x
.
dim
()
-
2
assert
x
.
shape
[
1
]
%
(
patch_size
**
DIM
)
==
0
,
f
"Second dimension of input tensor must be divisible by patch size to unpatchify, got
{
x
.
shape
[
1
]
}
and
{
patch_size
**
DIM
}
"
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
]
//
(
patch_size
**
DIM
),
*
([
patch_size
]
*
DIM
),
*
(
x
.
shape
[
-
DIM
:]))
x
=
x
.
permute
(
0
,
1
,
*
(
sum
([[
2
+
DIM
+
i
,
2
+
i
]
for
i
in
range
(
DIM
)],
[])))
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
],
*
[
x
.
shape
[
2
+
2
*
i
]
*
patch_size
for
i
in
range
(
DIM
)])
return
x
TRELLIS.2_DCU/trellis2/modules/transformer/__init__.py
0 → 100644
View file @
f05e915f
from
.blocks
import
*
from
.modulated
import
*
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/blocks.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/transformer/__pycache__/modulated.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/transformer/blocks.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..attention
import
MultiHeadAttention
from
..norm
import
LayerNorm32
class
AbsolutePositionEmbedder
(
nn
.
Module
):
"""
Embeds spatial positions into vector representations.
"""
def
__init__
(
self
,
channels
:
int
,
in_channels
:
int
=
3
):
super
().
__init__
()
self
.
channels
=
channels
self
.
in_channels
=
in_channels
self
.
freq_dim
=
channels
//
in_channels
//
2
self
.
freqs
=
torch
.
arange
(
self
.
freq_dim
,
dtype
=
torch
.
float32
)
/
self
.
freq_dim
self
.
freqs
=
1.0
/
(
10000
**
self
.
freqs
)
def
_sin_cos_embedding
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Create sinusoidal position embeddings.
Args:
x: a 1-D Tensor of N indices
Returns:
an (N, D) Tensor of positional embeddings.
"""
self
.
freqs
=
self
.
freqs
.
to
(
x
.
device
)
out
=
torch
.
outer
(
x
,
self
.
freqs
)
out
=
torch
.
cat
([
torch
.
sin
(
out
),
torch
.
cos
(
out
)],
dim
=-
1
)
return
out
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Args:
x (torch.Tensor): (N, D) tensor of spatial positions
"""
N
,
D
=
x
.
shape
assert
D
==
self
.
in_channels
,
"Input dimension must match number of input channels"
embed
=
self
.
_sin_cos_embedding
(
x
.
reshape
(
-
1
))
embed
=
embed
.
reshape
(
N
,
-
1
)
if
embed
.
shape
[
1
]
<
self
.
channels
:
embed
=
torch
.
cat
([
embed
,
torch
.
zeros
(
N
,
self
.
channels
-
embed
.
shape
[
1
],
device
=
embed
.
device
)],
dim
=-
1
)
return
embed
class
FeedForwardNet
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
mlp_ratio
:
float
=
4.0
):
super
().
__init__
()
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
channels
,
int
(
channels
*
mlp_ratio
)),
nn
.
GELU
(
approximate
=
"tanh"
),
nn
.
Linear
(
int
(
channels
*
mlp_ratio
),
channels
),
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
mlp
(
x
)
class
TransformerBlock
(
nn
.
Module
):
"""
Transformer block (MSA + FFN).
"""
def
__init__
(
self
,
channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"windowed"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
int
]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
int
,
int
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
ln_affine
:
bool
=
True
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
attn
=
MultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
mlp
=
FeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
def
_forward
(
self
,
x
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
h
=
self
.
norm1
(
x
)
h
=
self
.
attn
(
h
,
phases
=
phases
)
x
=
x
+
h
h
=
self
.
norm2
(
x
)
h
=
self
.
mlp
(
h
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
phases
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
phases
)
class
TransformerCrossBlock
(
nn
.
Module
):
"""
Transformer cross-attention block (MSA + MCA + FFN).
"""
def
__init__
(
self
,
channels
:
int
,
ctx_channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"windowed"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
int
,
int
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qk_rms_norm_cross
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
ln_affine
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
norm3
=
LayerNorm32
(
channels
,
elementwise_affine
=
ln_affine
,
eps
=
1e-6
)
self
.
self_attn
=
MultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
type
=
"self"
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
cross_attn
=
MultiHeadAttention
(
channels
,
ctx_channels
=
ctx_channels
,
num_heads
=
num_heads
,
type
=
"cross"
,
attn_mode
=
"full"
,
qkv_bias
=
qkv_bias
,
qk_rms_norm
=
qk_rms_norm_cross
,
)
self
.
mlp
=
FeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
def
_forward
(
self
,
x
:
torch
.
Tensor
,
context
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
h
=
self
.
norm1
(
x
)
h
=
self
.
self_attn
(
h
,
phases
=
phases
)
x
=
x
+
h
h
=
self
.
norm2
(
x
)
h
=
self
.
cross_attn
(
h
,
context
)
x
=
x
+
h
h
=
self
.
norm3
(
x
)
h
=
self
.
mlp
(
h
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
,
context
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
context
,
phases
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
context
,
phases
)
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/transformer/modulated.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..attention
import
MultiHeadAttention
from
..norm
import
LayerNorm32
from
.blocks
import
FeedForwardNet
class
ModulatedTransformerBlock
(
nn
.
Module
):
"""
Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def
__init__
(
self
,
channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"windowed"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
int
,
int
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
share_mod
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
share_mod
=
share_mod
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
attn
=
MultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
mlp
=
FeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
if
not
share_mod
:
self
.
adaLN_modulation
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
channels
,
6
*
channels
,
bias
=
True
)
)
else
:
self
.
modulation
=
nn
.
Parameter
(
torch
.
randn
(
6
*
channels
)
/
channels
**
0.5
)
def
_forward
(
self
,
x
:
torch
.
Tensor
,
mod
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
share_mod
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
(
self
.
modulation
+
mod
).
type
(
mod
.
dtype
).
chunk
(
6
,
dim
=
1
)
else
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
self
.
adaLN_modulation
(
mod
).
chunk
(
6
,
dim
=
1
)
h
=
self
.
norm1
(
x
)
h
=
h
*
(
1
+
scale_msa
.
unsqueeze
(
1
))
+
shift_msa
.
unsqueeze
(
1
)
h
=
self
.
attn
(
h
,
phases
=
phases
)
h
=
h
*
gate_msa
.
unsqueeze
(
1
)
x
=
x
+
h
h
=
self
.
norm2
(
x
)
h
=
h
*
(
1
+
scale_mlp
.
unsqueeze
(
1
))
+
shift_mlp
.
unsqueeze
(
1
)
h
=
self
.
mlp
(
h
)
h
=
h
*
gate_mlp
.
unsqueeze
(
1
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
,
mod
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
mod
,
phases
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
mod
,
phases
)
class
ModulatedTransformerCrossBlock
(
nn
.
Module
):
"""
Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def
__init__
(
self
,
channels
:
int
,
ctx_channels
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
=
4.0
,
attn_mode
:
Literal
[
"full"
,
"windowed"
]
=
"full"
,
window_size
:
Optional
[
int
]
=
None
,
shift_window
:
Optional
[
Tuple
[
int
,
int
,
int
]]
=
None
,
use_checkpoint
:
bool
=
False
,
use_rope
:
bool
=
False
,
rope_freq
:
Tuple
[
int
,
int
]
=
(
1.0
,
10000.0
),
qk_rms_norm
:
bool
=
False
,
qk_rms_norm_cross
:
bool
=
False
,
qkv_bias
:
bool
=
True
,
share_mod
:
bool
=
False
,
):
super
().
__init__
()
self
.
use_checkpoint
=
use_checkpoint
self
.
share_mod
=
share_mod
self
.
norm1
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm2
=
LayerNorm32
(
channels
,
elementwise_affine
=
True
,
eps
=
1e-6
)
self
.
norm3
=
LayerNorm32
(
channels
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
self_attn
=
MultiHeadAttention
(
channels
,
num_heads
=
num_heads
,
type
=
"self"
,
attn_mode
=
attn_mode
,
window_size
=
window_size
,
shift_window
=
shift_window
,
qkv_bias
=
qkv_bias
,
use_rope
=
use_rope
,
rope_freq
=
rope_freq
,
qk_rms_norm
=
qk_rms_norm
,
)
self
.
cross_attn
=
MultiHeadAttention
(
channels
,
ctx_channels
=
ctx_channels
,
num_heads
=
num_heads
,
type
=
"cross"
,
attn_mode
=
"full"
,
qkv_bias
=
qkv_bias
,
qk_rms_norm
=
qk_rms_norm_cross
,
)
self
.
mlp
=
FeedForwardNet
(
channels
,
mlp_ratio
=
mlp_ratio
,
)
if
not
share_mod
:
self
.
adaLN_modulation
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
channels
,
6
*
channels
,
bias
=
True
)
)
else
:
self
.
modulation
=
nn
.
Parameter
(
torch
.
randn
(
6
*
channels
)
/
channels
**
0.5
)
def
_forward
(
self
,
x
:
torch
.
Tensor
,
mod
:
torch
.
Tensor
,
context
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
share_mod
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
(
self
.
modulation
+
mod
).
type
(
mod
.
dtype
).
chunk
(
6
,
dim
=
1
)
else
:
shift_msa
,
scale_msa
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
=
self
.
adaLN_modulation
(
mod
).
chunk
(
6
,
dim
=
1
)
h
=
self
.
norm1
(
x
)
h
=
h
*
(
1
+
scale_msa
.
unsqueeze
(
1
))
+
shift_msa
.
unsqueeze
(
1
)
h
=
self
.
self_attn
(
h
,
phases
=
phases
)
h
=
h
*
gate_msa
.
unsqueeze
(
1
)
x
=
x
+
h
h
=
self
.
norm2
(
x
)
h
=
self
.
cross_attn
(
h
,
context
)
x
=
x
+
h
h
=
self
.
norm3
(
x
)
h
=
h
*
(
1
+
scale_mlp
.
unsqueeze
(
1
))
+
shift_mlp
.
unsqueeze
(
1
)
h
=
self
.
mlp
(
h
)
h
=
h
*
gate_mlp
.
unsqueeze
(
1
)
x
=
x
+
h
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
,
mod
:
torch
.
Tensor
,
context
:
torch
.
Tensor
,
phases
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_checkpoint
:
return
torch
.
utils
.
checkpoint
.
checkpoint
(
self
.
_forward
,
x
,
mod
,
context
,
phases
,
use_reentrant
=
False
)
else
:
return
self
.
_forward
(
x
,
mod
,
context
,
phases
)
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/utils.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
..modules
import
sparse
as
sp
MIX_PRECISION_MODULES
=
(
nn
.
Conv1d
,
nn
.
Conv2d
,
nn
.
Conv3d
,
nn
.
ConvTranspose1d
,
nn
.
ConvTranspose2d
,
nn
.
ConvTranspose3d
,
nn
.
Linear
,
sp
.
SparseConv3d
,
sp
.
SparseInverseConv3d
,
sp
.
SparseLinear
,
)
def
convert_module_to_f16
(
l
):
"""
Convert primitive modules to float16.
"""
if
isinstance
(
l
,
MIX_PRECISION_MODULES
):
for
p
in
l
.
parameters
():
p
.
data
=
p
.
data
.
half
()
def
convert_module_to_bf16
(
l
):
"""
Convert primitive modules to bfloat16.
"""
if
isinstance
(
l
,
MIX_PRECISION_MODULES
):
for
p
in
l
.
parameters
():
p
.
data
=
p
.
data
.
bfloat16
()
def
convert_module_to_f32
(
l
):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if
isinstance
(
l
,
MIX_PRECISION_MODULES
):
for
p
in
l
.
parameters
():
p
.
data
=
p
.
data
.
float
()
def
convert_module_to
(
l
,
dtype
):
"""
Convert primitive modules to the given dtype.
"""
if
isinstance
(
l
,
MIX_PRECISION_MODULES
):
for
p
in
l
.
parameters
():
p
.
data
=
p
.
data
.
to
(
dtype
)
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
def
scale_module
(
module
,
scale
):
"""
Scale the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
mul_
(
scale
)
return
module
def
modulate
(
x
,
shift
,
scale
):
return
x
*
(
1
+
scale
.
unsqueeze
(
1
))
+
shift
.
unsqueeze
(
1
)
def
manual_cast
(
tensor
,
dtype
):
"""
Cast if autocast is not enabled.
"""
if
not
torch
.
is_autocast_enabled
():
return
tensor
.
type
(
dtype
)
return
tensor
def
str_to_dtype
(
dtype_str
:
str
):
return
{
'f16'
:
torch
.
float16
,
'fp16'
:
torch
.
float16
,
'float16'
:
torch
.
float16
,
'bf16'
:
torch
.
bfloat16
,
'bfloat16'
:
torch
.
bfloat16
,
'f32'
:
torch
.
float32
,
'fp32'
:
torch
.
float32
,
'float32'
:
torch
.
float32
,
}[
dtype_str
]
TRELLIS.2_DCU/trellis2/pipelines/__init__.py
0 → 100644
View file @
f05e915f
import
importlib
__attributes
=
{
"Trellis2ImageTo3DPipeline"
:
"trellis2_image_to_3d"
,
"Trellis2TexturingPipeline"
:
"trellis2_texturing"
,
}
__submodules
=
[
'samplers'
,
'rembg'
]
__all__
=
list
(
__attributes
.
keys
())
+
__submodules
def
__getattr__
(
name
):
if
name
not
in
globals
():
if
name
in
__attributes
:
module_name
=
__attributes
[
name
]
module
=
importlib
.
import_module
(
f
".
{
module_name
}
"
,
__name__
)
globals
()[
name
]
=
getattr
(
module
,
name
)
elif
name
in
__submodules
:
module
=
importlib
.
import_module
(
f
".
{
name
}
"
,
__name__
)
globals
()[
name
]
=
module
else
:
raise
AttributeError
(
f
"module
{
__name__
}
has no attribute
{
name
}
"
)
return
globals
()[
name
]
def
from_pretrained
(
path
:
str
):
"""
Load a pipeline from a model folder or a Hugging Face model hub.
Args:
path: The path to the model. Can be either local path or a Hugging Face model name.
"""
import
os
import
json
is_local
=
os
.
path
.
exists
(
f
"
{
path
}
/pipeline.json"
)
if
is_local
:
config_file
=
f
"
{
path
}
/pipeline.json"
else
:
from
huggingface_hub
import
hf_hub_download
config_file
=
hf_hub_download
(
path
,
"pipeline.json"
)
with
open
(
config_file
,
'r'
)
as
f
:
config
=
json
.
load
(
f
)
return
globals
()[
config
[
'name'
]].
from_pretrained
(
path
)
# For PyLance
if
__name__
==
'__main__'
:
from
.
import
samplers
,
rembg
from
.trellis2_image_to_3d
import
Trellis2ImageTo3DPipeline
from
.trellis2_texturing
import
Trellis2TexturingPipeline
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/base.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/__pycache__/trellis2_image_to_3d.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/base.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..
import
models
class
Pipeline
:
"""
A base class for pipelines.
"""
def
__init__
(
self
,
models
:
dict
[
str
,
nn
.
Module
]
=
None
,
):
if
models
is
None
:
return
self
.
models
=
models
for
model
in
self
.
models
.
values
():
model
.
eval
()
@
classmethod
def
from_pretrained
(
cls
,
path
:
str
,
config_file
:
str
=
"pipeline.json"
)
->
"Pipeline"
:
"""
Load a pretrained model.
"""
import
os
import
json
is_local
=
os
.
path
.
exists
(
f
"
{
path
}
/
{
config_file
}
"
)
if
is_local
:
config_file
=
f
"
{
path
}
/
{
config_file
}
"
else
:
from
huggingface_hub
import
hf_hub_download
config_file
=
hf_hub_download
(
path
,
config_file
)
with
open
(
config_file
,
'r'
)
as
f
:
args
=
json
.
load
(
f
)[
'args'
]
_models
=
{}
for
k
,
v
in
args
[
'models'
].
items
():
if
hasattr
(
cls
,
'model_names_to_load'
)
and
k
not
in
cls
.
model_names_to_load
:
continue
try
:
_models
[
k
]
=
models
.
from_pretrained
(
f
"
{
path
}
/
{
v
}
"
)
except
Exception
as
e
:
_models
[
k
]
=
models
.
from_pretrained
(
v
)
new_pipeline
=
cls
(
_models
)
new_pipeline
.
_pretrained_args
=
args
return
new_pipeline
@
property
def
device
(
self
)
->
torch
.
device
:
if
hasattr
(
self
,
'_device'
):
return
self
.
_device
for
model
in
self
.
models
.
values
():
if
hasattr
(
model
,
'device'
):
return
model
.
device
for
model
in
self
.
models
.
values
():
if
hasattr
(
model
,
'parameters'
):
return
next
(
model
.
parameters
()).
device
raise
RuntimeError
(
"No device found."
)
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
for
model
in
self
.
models
.
values
():
model
.
to
(
device
)
def
cuda
(
self
)
->
None
:
self
.
to
(
torch
.
device
(
"cuda"
))
def
cpu
(
self
)
->
None
:
self
.
to
(
torch
.
device
(
"cpu"
))
\ No newline at end of file
TRELLIS.2_DCU/trellis2/pipelines/rembg/BiRefNet.py
0 → 100644
View file @
f05e915f
from
typing
import
*
from
transformers
import
AutoModelForImageSegmentation
import
torch
from
torchvision
import
transforms
from
PIL
import
Image
class
BiRefNet
:
def
__init__
(
self
,
model_name
:
str
=
"ZhengPeng7/BiRefNet"
):
# transformers 5.x calls all_tied_weights_keys.keys() during model loading,
# but BiRefNet (trust_remote_code) was written for older transformers and doesn't
# define this attribute. Patch the base class before loading.
from
transformers
import
PreTrainedModel
if
not
hasattr
(
PreTrainedModel
,
'_trellis2_patched'
):
_method_name
=
'_move_missing_keys_from_meta_to_device'
if
hasattr
(
PreTrainedModel
,
'_move_missing_keys_from_meta_to_device'
)
else
'_move_missing_keys_from_meta_to_cpu'
_orig
=
getattr
(
PreTrainedModel
,
_method_name
)
def
_patched
(
self_model
,
*
args
,
**
kwargs
):
if
not
hasattr
(
self_model
,
'all_tied_weights_keys'
):
self_model
.
all_tied_weights_keys
=
{}
return
_orig
(
self_model
,
*
args
,
**
kwargs
)
setattr
(
PreTrainedModel
,
_method_name
,
_patched
)
PreTrainedModel
.
_trellis2_patched
=
True
self
.
model
=
AutoModelForImageSegmentation
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
self
.
model
.
eval
()
self
.
transform_image
=
transforms
.
Compose
(
[
transforms
.
Resize
((
1024
,
1024
)),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.485
,
0.456
,
0.406
],
[
0.229
,
0.224
,
0.225
]),
]
)
def
to
(
self
,
device
:
str
):
self
.
model
.
to
(
device
)
def
cuda
(
self
):
self
.
model
.
cuda
()
def
cpu
(
self
):
self
.
model
.
cpu
()
def
__call__
(
self
,
image
:
Image
.
Image
)
->
Image
.
Image
:
image_size
=
image
.
size
input_images
=
self
.
transform_image
(
image
).
unsqueeze
(
0
).
to
(
"cuda"
)
# Prediction
with
torch
.
no_grad
():
preds
=
self
.
model
(
input_images
)[
-
1
].
sigmoid
().
cpu
()
pred
=
preds
[
0
].
squeeze
()
pred_pil
=
transforms
.
ToPILImage
()(
pred
)
mask
=
pred_pil
.
resize
(
image_size
)
image
.
putalpha
(
mask
)
return
image
TRELLIS.2_DCU/trellis2/pipelines/rembg/__init__.py
0 → 100644
View file @
f05e915f
from
.BiRefNet
import
*
TRELLIS.2_DCU/trellis2/pipelines/rembg/__pycache__/BiRefNet.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/rembg/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/pipelines/samplers/__init__.py
0 → 100644
View file @
f05e915f
from
.base
import
Sampler
from
.flow_euler
import
(
FlowEulerSampler
,
FlowEulerCfgSampler
,
FlowEulerGuidanceIntervalSampler
,
)
\ No newline at end of file
Prev
1
…
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment