Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
TRELLIS.2
Commits
f05e915f
Commit
f05e915f
authored
May 27, 2026
by
weishb
Browse files
首次提交
parent
297bf637
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
913 additions
and
0 deletions
+913
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/__pycache__/conv_flex_gemm.cpython-310.pyc
...es/sparse/conv/__pycache__/conv_flex_gemm.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/config.py
TRELLIS.2_DCU/trellis2/modules/sparse/conv/config.py
+3
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv.py
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv.py
+30
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_flex_gemm.py
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_flex_gemm.py
+136
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_none.py
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_none.py
+293
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_spconv.py
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_spconv.py
+73
-0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_torchsparse.py
...IS.2_DCU/trellis2/modules/sparse/conv/conv_torchsparse.py
+30
-0
TRELLIS.2_DCU/trellis2/modules/sparse/linear.py
TRELLIS.2_DCU/trellis2/modules/sparse/linear.py
+43
-0
TRELLIS.2_DCU/trellis2/modules/sparse/nonlinearity.py
TRELLIS.2_DCU/trellis2/modules/sparse/nonlinearity.py
+35
-0
TRELLIS.2_DCU/trellis2/modules/sparse/norm.py
TRELLIS.2_DCU/trellis2/modules/sparse/norm.py
+64
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__init__.py
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__init__.py
+2
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/__init__.cpython-310.pyc
...dules/sparse/spatial/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/basic.cpython-310.pyc
.../modules/sparse/spatial/__pycache__/basic.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/spatial2channel.cpython-310.pyc
...parse/spatial/__pycache__/spatial2channel.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/basic.py
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/basic.py
+109
-0
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/spatial2channel.py
....2_DCU/trellis2/modules/sparse/spatial/spatial2channel.py
+93
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__init__.py
...LIS.2_DCU/trellis2/modules/sparse/transformer/__init__.py
+2
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc
...s/sparse/transformer/__pycache__/__init__.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc
...les/sparse/transformer/__pycache__/blocks.cpython-310.pyc
+0
-0
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc
.../sparse/transformer/__pycache__/modulated.cpython-310.pyc
+0
-0
No files found.
TRELLIS.2_DCU/trellis2/modules/sparse/conv/__pycache__/conv_flex_gemm.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/conv/config.py
0 → 100644
View file @
f05e915f
# Algorithm selector read by the spconv backend when a layer is constructed.
# Accepted values: 'auto', 'implicit_gemm', 'native'.
SPCONV_ALGO = 'auto'
# Algorithm name handed to flex_gemm before each convolution.
# Accepted values: 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk',
# 'masked_implicit_gemm', 'masked_implicit_gemm_splitk'.
FLEX_GEMM_ALGO = 'masked_implicit_gemm'
# Ratio of hashmap size to input size
FLEX_GEMM_HASHMAP_RATIO = 2.0
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv.py
0 → 100644
View file @
f05e915f
from
..
import
config
import
importlib
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
_backends
=
{}
class SparseConv3d(nn.Module):
    """
    Sparse 3D convolution that delegates to the backend named by ``config.CONV``.

    The backend module ``conv_<name>`` is imported on first use and memoised in
    the module-level ``_backends`` dict. The backend's ``sparse_conv3d_init``
    populates this module; ``sparse_conv3d_forward`` performs the computation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        backend_name = config.CONV
        if backend_name not in _backends:
            # Relative import of the sibling module implementing this backend.
            module_path = f'..conv_{backend_name}'
            _backends[backend_name] = importlib.import_module(module_path, __name__)
        backend = _backends[backend_name]
        backend.sparse_conv3d_init(
            self, in_channels, out_channels, kernel_size,
            stride, dilation, padding, bias, indice_key,
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Run the convolution through the backend currently named by ``config.CONV``."""
        backend = _backends[config.CONV]
        return backend.sparse_conv3d_forward(self, x)
class SparseInverseConv3d(nn.Module):
    """
    Sparse inverse 3D convolution that delegates to the backend named by ``config.CONV``.

    The backend module ``conv_<name>`` is imported on first use and memoised in
    the module-level ``_backends`` dict. The backend's
    ``sparse_inverse_conv3d_init`` populates this module;
    ``sparse_inverse_conv3d_forward`` performs the computation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super().__init__()
        backend_name = config.CONV
        if backend_name not in _backends:
            # Relative import of the sibling module implementing this backend.
            module_path = f'..conv_{backend_name}'
            _backends[backend_name] = importlib.import_module(module_path, __name__)
        backend = _backends[backend_name]
        backend.sparse_inverse_conv3d_init(
            self, in_channels, out_channels, kernel_size,
            stride, dilation, bias, indice_key,
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Run the inverse convolution through the backend currently named by ``config.CONV``."""
        backend = _backends[config.CONV]
        return backend.sparse_inverse_conv3d_forward(self, x)
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_flex_gemm.py
0 → 100644
View file @
f05e915f
import
math
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
from
.
import
config
from
..
import
config
as
sparse_config
from
..linear
import
ROCM_SAFE_CHUNK
import
flex_gemm
from
flex_gemm.ops.spconv
import
sparse_submanifold_conv3d
from
flex_gemm.ops.spconv.submanifold_conv3d
import
SubMConv3dFunction
,
SubMConv3dNeighborCache
from
flex_gemm.ops
import
utils
as
flex_utils
import
flex_gemm.kernels
as
flex_kernels
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """
    Initialise ``self`` as a flex_gemm submanifold sparse convolution.

    Args:
        self: module that receives the attributes and parameters.
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: int or 3-sequence.
        stride: must be 1, or a sequence whose entries are all 1.
        dilation: int or 3-sequence.
        padding: must be None.
        bias: whether to create a learnable bias of shape (out_channels,).
        indice_key: accepted for interface parity; not used by this backend.

    Raises:
        AssertionError: if the stride is not all ones or padding is given.

    The weight is initialised in (Co, Ci, Kd, Kh, Kw) layout and then stored
    contiguously as (Co, Kd, Kh, Kw, Ci).
    """
    def _triple(value):
        return tuple(value) if isinstance(value, (list, tuple)) else (value, ) * 3

    # Normalise first so that stride=(1, 1, 1) is accepted just like stride=1.
    stride_3d = _triple(stride)
    assert all(s == 1 for s in stride_3d) and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)'

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _triple(kernel_size)
    self.stride = stride_3d
    self.dilation = _triple(dilation)

    self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
    else:
        self.register_parameter("bias", None)

    # initialize parameters
    torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        # fan_in of a (Co, Ci, Kd, Kh, Kw) weight is Ci * Kd * Kh * Kw.
        fan_in = in_channels * math.prod(self.kernel_size)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
    self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
def _sparse_conv3d_explicit_gemm_chunked(feats, neighbor_map, weight, bias, N, V, Co, Ci):
    """
    Chunked explicit-GEMM sparse conv: im2col + torch.mm in ROCM_SAFE_CHUNK-sized pieces.
    Avoids the flex_gemm Triton kernel for large N on ROCm GFX1201.

    Args:
        feats: [N, Ci] input features.
        neighbor_map: [N, V] neighbour indices; 0xffffffff marks "no neighbour".
        weight: [Co, V, Ci] kernel (reshaped from [Co, Kd, Kh, Kw, Ci]).
        bias: [Co] tensor or None.
        N, V, Co, Ci: the sizes named above.

    Returns:
        [N, Co] output features on ``feats``' device and dtype.
    """
    # [Co, V, Ci] -> [V*Ci, Co]; reshape also accepts a non-contiguous weight.
    weight_2d = weight.reshape(Co, V * Ci).t().contiguous()
    output = torch.zeros(N, Co, device=feats.device, dtype=feats.dtype)

    for s in range(0, N, ROCM_SAFE_CHUNK):
        e = min(s + ROCM_SAFE_CHUNK, N)
        chunk_size = e - s
        flat_nm = neighbor_map[s:e].long().reshape(-1)  # [chunk*V]
        # The sentinel is 0xffffffff. If the map is stored signed, the same bit
        # pattern reads back as -1, which would otherwise index the last row.
        valid = (flat_nm != 0xffffffff) & (flat_nm >= 0)

        # im2col rows without a neighbour stay zero.
        im2col = torch.zeros(chunk_size * V, Ci, device=feats.device, dtype=feats.dtype)
        im2col[valid] = feats[flat_nm[valid]]

        # GEMM: [chunk, V*Ci] @ [V*Ci, Co] -> [chunk, Co]
        output[s:e] = torch.mm(im2col.view(chunk_size, V * Ci), weight_2d)

    if bias is not None:
        output = output + bias
    return output
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass of the flex_gemm submanifold sparse convolution.

    The neighbour cache is looked up on ``x`` under a key derived from the
    kernel size and dilation, and is registered on ``x`` when it had to be built.

    When ``sparse_config.ROCM_SAFE_SPCONV`` is truthy and the input has more than
    ``ROCM_SAFE_CHUNK`` rows, only the neighbour map is built with the flex_gemm
    hash kernel and the convolution is evaluated by
    ``_sparse_conv3d_explicit_gemm_chunked``. Otherwise
    ``sparse_submanifold_conv3d`` is used.
    """
    flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO)
    flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO)

    Co, Kd, Kh, Kw, Ci = self.weight.shape
    N = x.feats.shape[0]
    V = Kd * Kh * Kw
    neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
    neighbor_cache = x.get_spatial_cache(neighbor_cache_key)

    # ROCm safe spconv: build neighbor map normally, then use chunked torch.mm instead of Triton
    if sparse_config.ROCM_SAFE_SPCONV and N > ROCM_SAFE_CHUNK:
        # flex_utils, flex_kernels and SubMConv3dNeighborCache are the
        # module-level imports; they are not re-imported on every call.
        if neighbor_cache is None:
            # Build neighbor map using the HIP hash kernel (small/fast operation)
            hashmap_keys, hashmap_vals = flex_utils.init_hashmap(
                torch.Size([*x.shape, *x.spatial_shape]),
                int(config.FLEX_GEMM_HASHMAP_RATIO * N),
                x.feats.device,
            )
            neighbor_map = flex_kernels.cuda.hashmap_build_submanifold_conv_neighbour_map_cuda(
                hashmap_keys, hashmap_vals, x.coords,
                x.spatial_shape[0], x.spatial_shape[1], x.spatial_shape[2],
                Kw, Kh, Kd,
                self.dilation[0], self.dilation[1], self.dilation[2],
            )
            # Store minimal cache so we skip rebuild next call
            neighbor_cache_ = SubMConv3dNeighborCache(neighbor_map=neighbor_map)
            x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
        else:
            neighbor_map = neighbor_cache['neighbor_map']

        weight_flat = self.weight.reshape(Co, V, Ci)
        out = _sparse_conv3d_explicit_gemm_chunked(
            x.feats, neighbor_map, weight_flat, self.bias, N, V, Co, Ci
        )
        print(f"[ROCM_SAFE_SPCONV] N={N} used chunked explicit GEMM (V={V})")
        return x.replace(out)

    # Normal path: flex_gemm Triton kernel
    out, neighbor_cache_ = sparse_submanifold_conv3d(
        x.feats,
        x.coords,
        torch.Size([*x.shape, *x.spatial_shape]),
        self.weight,
        self.bias,
        neighbor_cache,
        self.dilation
    )

    if neighbor_cache is None:
        x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)

    return x.replace(out)
def sparse_inverse_conv3d_init(self, *args, **kwargs):
    """Placeholder: the flex_gemm backend has no inverse convolution; always raises."""
    message = 'SparseInverseConv3d with flex_gemm is not implemented yet'
    raise NotImplementedError(message)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """Placeholder: the flex_gemm backend has no inverse convolution; always raises."""
    message = 'SparseInverseConv3d with flex_gemm is not implemented yet'
    raise NotImplementedError(message)
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_none.py
0 → 100644
View file @
f05e915f
"""
Native PyTorch sparse convolution implementation.
No CUDA kernels, no Triton - works on any PyTorch backend (CUDA, ROCm, CPU).
MUCH SLOWER than optimized backends but numerically stable.
Use for debugging or when other backends fail.
"""
import
math
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """
    Initialize sparse 3D convolution layer (native PyTorch backend).

    Args:
        self: module that receives the attributes and parameters.
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: int or 3-sequence.
        stride: must be 1, or a sequence whose entries are all 1.
        dilation: int or 3-sequence.
        padding: must be None.
        bias: whether to create a learnable bias of shape (out_channels,).
        indice_key: accepted for interface parity; not used by this backend.

    Raises:
        AssertionError: if the stride is not all ones or padding is given.
    """
    def _triple(value):
        return tuple(value) if isinstance(value, (list, tuple)) else (value,) * 3

    # Normalise first so that stride=(1, 1, 1) is accepted just like stride=1.
    stride_3d = _triple(stride)
    assert all(s == 1 for s in stride_3d) and padding is None, \
        'Native implementation only supports submanifold sparse convolution (stride=1, padding=None)'

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _triple(kernel_size)
    self.stride = stride_3d
    self.dilation = _triple(dilation)

    # Weight shape: (out_channels, kernel_d, kernel_h, kernel_w, in_channels)
    # Matches FlexGEMM's layout for compatibility
    self.weight = nn.Parameter(torch.empty((out_channels, *self.kernel_size, in_channels)))

    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
    else:
        self.register_parameter("bias", None)

    # Initialize parameters. The (Co, K, Ci) view shares storage with the
    # parameter, so the in-place initialiser fills the weight itself.
    torch.nn.init.kaiming_uniform_(self.weight.view(out_channels, -1, in_channels), a=math.sqrt(5))
    if self.bias is not None:
        fan_in = in_channels * math.prod(self.kernel_size)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)
def
_build_coord_map
(
coords
):
"""Build hash map from coordinates to indices."""
coord_map
=
{}
for
i
in
range
(
coords
.
shape
[
0
]):
key
=
(
int
(
coords
[
i
,
0
].
item
()),
int
(
coords
[
i
,
1
].
item
()),
int
(
coords
[
i
,
2
].
item
()),
int
(
coords
[
i
,
3
].
item
()))
coord_map
[
key
]
=
i
return
coord_map
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass for native sparse 3D convolution.

    For every voxel, the features of the neighbours that exist in the input are
    multiplied by the kernel slice of the matching offset and accumulated.
    The neighbour list is cached on ``x`` so later layers with the same kernel
    size and dilation skip the search.

    Returns:
        ``x`` with its features replaced by the [N, C_out] result.
    """
    coords = x.coords  # [N, 4] - (batch_idx, x, y, z)
    feats = x.feats    # [N, C_in]
    N = coords.shape[0]
    C_out = self.weight.shape[0]
    Kd, Kh, Kw = self.kernel_size
    dk, dh, dw = self.dilation

    # Center offsets
    kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2

    # Dilation components are separated so that e.g. (1, 11, 1) and (11, 1, 1)
    # do not collide on the same cache entry.
    cache_key = f'NativeConv3d_neighbors_{Kw}x{Kh}x{Kd}_d{dk}_{dh}_{dw}'
    neighbor_cache = x.get_spatial_cache(cache_key)

    if neighbor_cache is None:
        coord_map = _build_coord_map(coords)
        # (flat kernel index, offset along each spatial axis), in kernel order.
        offsets = [
            (kd * Kh * Kw + kh * Kw + kw, (kd - kd_c) * dk, (kh - kh_c) * dh, (kw - kw_c) * dw)
            for kd in range(Kd) for kh in range(Kh) for kw in range(Kw)
        ]
        # neighbor_cache[i] = [(kernel_idx, neighbor_feat_idx), ...]
        neighbor_cache = []
        for row in coords.tolist():  # single host transfer
            b, cx, cy, cz = int(row[0]), int(row[1]), int(row[2]), int(row[3])
            neighbors = []
            for k_idx, ox, oy, oz in offsets:
                n_idx = coord_map.get((b, cx + ox, cy + oy, cz + oz))
                if n_idx is not None:
                    neighbors.append((k_idx, n_idx))
            neighbor_cache.append(neighbors)
        x.register_spatial_cache(cache_key, neighbor_cache)

    # Compute output
    output = torch.zeros(N, C_out, device=feats.device, dtype=feats.dtype)

    # The weight is stored as (C_out, Kd, Kh, Kw, C_in). Bring the kernel axis to
    # the front with a permute; a plain view(-1, C_out, C_in) would reinterpret
    # the memory and mix channels with kernel positions.
    # kernel_t[k] has shape [C_in, C_out].
    kernel_t = self.weight.reshape(C_out, -1, self.in_channels).permute(1, 2, 0)

    for i in range(N):
        for k_idx, n_idx in neighbor_cache[i]:
            output[i] += feats[n_idx] @ kernel_t[k_idx]

    if self.bias is not None:
        output = output + self.bias

    return x.replace(output)
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
    """
    Set up ``self`` as a native sparse inverse (transposed) 3D convolution.

    Stores the channel counts and the kernel size, stride and dilation as
    3-tuples, then creates and initialises the weight and the optional bias.
    ``indice_key`` is accepted but not used.
    """
    def _as_triple(value):
        if isinstance(value, (list, tuple)):
            return tuple(value)
        return (value,) * 3

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _as_triple(kernel_size)
    self.stride = _as_triple(stride)
    self.dilation = _as_triple(dilation)

    # Weight layout: (in_channels, kernel_d, kernel_h, kernel_w, out_channels);
    # in/out channels are swapped relative to the forward convolution.
    weight_shape = (in_channels, *self.kernel_size, out_channels)
    self.weight = nn.Parameter(torch.empty(weight_shape))

    if not bias:
        self.register_parameter("bias", None)
    else:
        self.bias = nn.Parameter(torch.empty(out_channels))

    # The 3-D view shares storage with the parameter, so it is filled in place.
    weight_view = self.weight.view(in_channels, -1, out_channels)
    torch.nn.init.kaiming_uniform_(weight_view, a=math.sqrt(5))

    if self.bias is None:
        return
    fan_in = in_channels * math.prod(self.kernel_size)
    if fan_in != 0:
        limit = 1 / math.sqrt(fan_in)
        torch.nn.init.uniform_(self.bias, -limit, limit)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass for sparse inverse 3D convolution.

    Each input voxel scatters ``feats[i] @ W[k]`` to every neighbour that exists
    in the input, i.e. the transpose of the forward convolution.

    NOTE: This implementation assumes stride=1 (no actual upsampling), so the
    output coordinates are the input coordinates.

    Returns:
        ``x`` with its features replaced by the [N, C_out] result.
    """
    coords = x.coords  # [N, 4]
    feats = x.feats    # [N, C_in]
    N = coords.shape[0]
    C_out = self.weight.shape[-1]  # out_channels is last dim for inverse conv
    Kd, Kh, Kw = self.kernel_size
    dk, dh, dw = self.dilation

    kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2

    coord_map = _build_coord_map(coords)
    # (flat kernel index, offset along each spatial axis), in kernel order.
    offsets = [
        (kd * Kh * Kw + kh * Kw + kw, (kd - kd_c) * dk, (kh - kh_c) * dh, (kw - kw_c) * dw)
        for kd in range(Kd) for kh in range(Kh) for kw in range(Kw)
    ]

    output = torch.zeros(N, C_out, device=feats.device, dtype=feats.dtype)

    # The weight is stored as (C_in, Kd, Kh, Kw, C_out). Bring the kernel axis to
    # the front with a permute; a plain view(-1, C_in, C_out) would reinterpret
    # the memory and mix channels with kernel positions.
    # weight_flat[k] has shape [C_in, C_out].
    weight_flat = self.weight.reshape(self.in_channels, -1, C_out).permute(1, 0, 2)

    # For each input voxel, scatter to its neighbors
    for i, row in enumerate(coords.tolist()):  # single host transfer
        b, cx, cy, cz = int(row[0]), int(row[1]), int(row[2]), int(row[3])
        for k_idx, ox, oy, oz in offsets:
            n_idx = coord_map.get((b, cx + ox, cy + oy, cz + oz))
            if n_idx is not None:
                output[n_idx] += feats[i] @ weight_flat[k_idx]

    if self.bias is not None:
        output = output + self.bias

    return x.replace(output)
# ============================================================================
# Vectorized implementation (faster but more memory)
# ============================================================================
def sparse_conv3d_forward_vectorized(self, x: SparseTensor) -> SparseTensor:
    """
    Vectorized implementation using batch operations.
    Faster than loop version but uses more memory: one [C_out, C_in] weight
    slice is gathered per (voxel, neighbour) pair.

    Returns:
        ``x`` with its features replaced by the [N, C_out] result.
    """
    coords = x.coords
    feats = x.feats
    N = coords.shape[0]
    C_in = feats.shape[1]
    C_out = self.weight.shape[0]
    Kd, Kh, Kw = self.kernel_size
    dk, dh, dw = self.dilation
    device = feats.device
    dtype = feats.dtype

    kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2

    coord_map = _build_coord_map(coords)
    # (flat kernel index, offset along each spatial axis), in kernel order.
    offsets = [
        (kd * Kh * Kw + kh * Kw + kw, (kd - kd_c) * dk, (kh - kh_c) * dh, (kw - kw_c) * dw)
        for kd in range(Kd) for kh in range(Kh) for kw in range(Kw)
    ]

    # Build all neighbor pairs
    src_indices = []     # Neighbour voxel whose features are read
    dst_indices = []     # Voxel that receives the contribution
    kernel_indices = []  # Which kernel weight to use

    for i, row in enumerate(coords.tolist()):  # single host transfer
        b, cx, cy, cz = int(row[0]), int(row[1]), int(row[2]), int(row[3])
        for k_idx, ox, oy, oz in offsets:
            n_idx = coord_map.get((b, cx + ox, cy + oy, cz + oz))
            if n_idx is not None:
                src_indices.append(n_idx)
                dst_indices.append(i)
                kernel_indices.append(k_idx)

    output = torch.zeros(N, C_out, device=device, dtype=dtype)

    if src_indices:
        src_indices = torch.tensor(src_indices, device=device, dtype=torch.long)
        dst_indices = torch.tensor(dst_indices, device=device, dtype=torch.long)
        kernel_indices = torch.tensor(kernel_indices, device=device, dtype=torch.long)

        # Gather features: [num_pairs, C_in]
        pair_feats = feats[src_indices]

        # The weight is stored as (C_out, Kd, Kh, Kw, C_in). Bring the kernel
        # axis to the front with a permute; a plain view(-1, C_out, C_in) would
        # reinterpret the memory and mix channels with kernel positions.
        weight_flat = self.weight.reshape(C_out, -1, C_in).permute(1, 0, 2)
        # Gather weights: [num_pairs, C_out, C_in]
        pair_weights = weight_flat[kernel_indices]

        # Compute contributions: [num_pairs, C_out]
        contributions = torch.bmm(pair_weights, pair_feats.unsqueeze(-1)).squeeze(-1)

        # Scatter to output
        output.scatter_add_(0, dst_indices.unsqueeze(-1).expand(-1, C_out), contributions)

    if self.bias is not None:
        output = output + self.bias

    return x.replace(output)
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_spconv.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
from
.
import
config
import
spconv.pytorch
as
spconv
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """
    Set up ``self`` as an spconv-backed sparse 3D convolution.

    ``config.SPCONV_ALGO`` selects the spconv algorithm ('native' or
    'implicit_gemm'; anything else leaves ``algo`` as None). A submanifold
    convolution is built when ``stride == 1`` and ``padding`` is None; otherwise
    a regular sparse convolution is built. The stride (as a 3-tuple) and the
    padding are stored on ``self`` for the forward pass.
    """
    algo_names = {'native': 'Native', 'implicit_gemm': 'MaskImplicitGemm'}
    chosen = algo_names.get(config.SPCONV_ALGO)
    algo = getattr(spconv.ConvAlgo, chosen) if chosen is not None else None

    shared = dict(dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
    is_submanifold = stride == 1 and (padding is None)
    if is_submanifold:
        self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, **shared)
    else:
        self.conv = spconv.SparseConv3d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, **shared,
        )

    if isinstance(stride, (list, tuple)):
        self.stride = tuple(stride)
    else:
        self.stride = (stride, stride, stride)
    self.padding = padding
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass of the spconv-backed sparse 3D convolution.

    When the convolution changes the spatial layout (any stride != 1 or a
    padding is set) and the batch holds more than one sample, the output rows
    are re-sorted by batch index. The unsorted spconv tensor and the inverse
    permutation are then registered on the output's spatial cache.
    """
    spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
    new_data = self.conv(x.data)
    new_shape = [x.shape[0], self.conv.out_channels]
    new_layout = None if spatial_changed else x.layout
    needs_sort = spatial_changed and (x.shape[0] != 1)

    if needs_sort:
        # spconv with non-1 stride breaks the contiguity of the output tensor; sort by the coords
        unsorted_data = new_data
        batch_column = unsorted_data.indices[:, 0]
        fwd = batch_column.argsort()
        positions = torch.arange(fwd.shape[0], device=fwd.device)
        bwd = torch.zeros_like(fwd).scatter_(0, fwd, positions)
        new_data = spconv.SparseConvTensor(
            unsorted_data.features[fwd],
            unsorted_data.indices[fwd],
            unsorted_data.spatial_shape,
            unsorted_data.batch_size,
        )  # type: ignore

    new_scale = tuple([s * stride for s, stride in zip(x._scale, self.stride)])
    out = SparseTensor(
        new_data,
        shape=torch.Size(new_shape),
        layout=new_layout,
        scale=new_scale,
        spatial_cache=x._spatial_cache,
    )

    if needs_sort:
        out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
        out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)

    return out
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
    """
    Set up ``self`` as an spconv-backed sparse inverse 3D convolution.

    Only the kernel size, bias and ``indice_key`` are forwarded to spconv;
    ``dilation`` is accepted but not passed on. The stride is stored on
    ``self`` as a 3-tuple for the forward pass.
    """
    self.conv = spconv.SparseInverseConv3d(
        in_channels,
        out_channels,
        kernel_size,
        bias=bias,
        indice_key=indice_key,
    )
    if isinstance(stride, (list, tuple)):
        self.stride = tuple(stride)
    else:
        self.stride = (stride, stride, stride)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass of the spconv-backed sparse inverse 3D convolution.

    When the matching strided convolution re-sorted its output, it registered
    the unsorted spconv tensor and the inverse permutation on the spatial
    cache. They are used here to restore spconv's original row order before the
    inverse convolution runs. The strided convolution registers those entries
    only for batch sizes other than 1, so when they are absent ``x.data`` is
    used as is.
    """
    spatial_changed = any(s != 1 for s in self.stride)

    data = x.data
    if spatial_changed:
        unsorted_data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
        bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
        # Missing entries mean no re-sort happened; dereferencing them would
        # fail on None.
        if unsorted_data is not None and bwd is not None:
            # recover the original spconv order
            data = unsorted_data.replace_feature(x.feats[bwd])

    new_data = self.conv(data)
    new_shape = [x.shape[0], self.conv.out_channels]
    new_layout = None if spatial_changed else x.layout
    out = SparseTensor(
        new_data, shape=torch.Size(new_shape), layout=new_layout,
        scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
        spatial_cache=x._spatial_cache,
    )
    return out
TRELLIS.2_DCU/trellis2/modules/sparse/conv/conv_torchsparse.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
import
torchsparse
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """
    Set up ``self`` as a torchsparse-backed sparse 3D convolution.

    ``padding`` and ``indice_key`` are accepted but not forwarded; the fifth
    positional argument given to torchsparse is always 0.
    """
    conv_args = (in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
    self.conv = torchsparse.nn.Conv3d(*conv_args)
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass of the torchsparse-backed sparse 3D convolution.

    The input layout is kept only when every stride is 1. The spatial cache is
    shared with the input and the scale is multiplied by the stride.
    """
    conv_out = self.conv(x.data)
    strides = self.conv.stride
    out_shape = torch.Size([x.shape[0], self.conv.out_channels])
    keeps_layout = all(s == 1 for s in strides)
    out = SparseTensor(
        conv_out,
        shape=out_shape,
        layout=x.layout if keeps_layout else None,
    )
    out._spatial_cache = x._spatial_cache
    out._scale = tuple([s * stride for s, stride in zip(x._scale, strides)])
    return out
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
    """
    Set up ``self`` as a torchsparse-backed transposed sparse 3D convolution.

    ``indice_key`` is accepted but not forwarded; the fifth positional argument
    given to torchsparse is always 0.
    """
    conv_args = (in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
    self.conv = torchsparse.nn.Conv3d(*conv_args, transposed=True)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    """
    Forward pass of the torchsparse-backed transposed sparse 3D convolution.

    The input layout is kept only when every stride is 1. The spatial cache is
    shared with the input and the scale is divided by the stride.
    """
    conv_out = self.conv(x.data)
    strides = self.conv.stride
    out_shape = torch.Size([x.shape[0], self.conv.out_channels])
    keeps_layout = all(s == 1 for s in strides)
    out = SparseTensor(
        conv_out,
        shape=out_shape,
        layout=x.layout if keeps_layout else None,
    )
    out._spatial_cache = x._spatial_cache
    out._scale = tuple([s / stride for s, stride in zip(x._scale, strides)])
    return out
TRELLIS.2_DCU/trellis2/modules/sparse/linear.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.
import
VarLenTensor
__all__
=
[
'SparseLinear'
,
'ROCM_SAFE_CHUNK'
,
'rocm_safe_linear'
,
]
# ROCm GFX1201 (RX 9070 XT) bug workaround:
# hipBLASLt and rocBLAS GEMM kernels corrupt memory (→ NaN) when N > ~800k
# for shapes like [N, K] @ [K, M] with small K/M. Chunking keeps each
# dispatch below the confirmed-safe threshold of 524288 rows.
ROCM_SAFE_CHUNK = 524_288


def rocm_safe_linear(feats: torch.Tensor, weight: torch.Tensor, bias=None, chunk_size=None) -> torch.Tensor:
    """
    F.linear with ROCm large-N chunking workaround.

    Args:
        feats: input of shape [N, ..., in_features].
        weight: [out_features, in_features].
        bias: optional [out_features] tensor.
        chunk_size: maximum number of leading rows per ``F.linear`` call;
            defaults to ``ROCM_SAFE_CHUNK``.

    Returns:
        The same values ``F.linear(feats, weight, bias)`` produces, computed in
        slices along the first dimension when N exceeds the chunk size.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    limit = ROCM_SAFE_CHUNK if chunk_size is None else int(chunk_size)
    if limit <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size!r}')

    N = feats.shape[0]
    # A 1-D input is a single feature vector; its only axis must not be split.
    if feats.dim() < 2 or N <= limit:
        return F.linear(feats, weight, bias)

    out = None
    for s in range(0, N, limit):
        piece = F.linear(feats[s:s + limit], weight, bias)
        if out is None:
            # Allocate from the first result so that the trailing shape and the
            # dtype match what the unchunked call would return.
            out = piece.new_empty((N, *piece.shape[1:]))
        out[s:s + limit] = piece
    return out
class SparseLinear(nn.Linear):
    """
    Linear layer that accepts either an object exposing ``feats``/``replace``
    or a plain tensor, and routes the matmul through ``rocm_safe_linear``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    # The forward pass goes through rocm_safe_linear rather than
    # nn.Linear.forward so that large inputs are processed in chunks.
    def forward(self, input):
        """Apply the layer; the result is wrapped via ``input.replace`` when available."""
        feats = getattr(input, 'feats', input)
        result = rocm_safe_linear(feats, self.weight, self.bias)
        if not hasattr(input, 'replace'):
            return result
        return input.replace(result)
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/sparse/nonlinearity.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
.
import
VarLenTensor
__all__
=
[
'SparseReLU'
,
'SparseSiLU'
,
'SparseGELU'
,
'SparseActivation'
]
class SparseReLU(nn.ReLU):
    """ReLU applied to the ``feats`` of a variable-length tensor."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
class SparseSiLU(nn.SiLU):
    """SiLU applied to the ``feats`` of a variable-length tensor."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
class SparseGELU(nn.GELU):
    """GELU applied to the ``feats`` of a variable-length tensor."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
class SparseActivation(nn.Module):
    """Wrap an arbitrary activation module so it acts on the ``feats`` of a variable-length tensor."""

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = self.activation(input.feats)
        return input.replace(activated)
TRELLIS.2_DCU/trellis2/modules/sparse/norm.py
0 → 100644
View file @
f05e915f
import
torch
import
torch.nn
as
nn
from
..utils
import
manual_cast
from
.
import
VarLenTensor
from
.
import
config
__all__
=
[
'SparseGroupNorm'
,
'SparseLayerNorm'
,
'SparseGroupNorm32'
,
'SparseLayerNorm32'
,
]
class SparseGroupNorm(nn.GroupNorm):
    """
    GroupNorm applied independently to each batch element of a variable-length tensor.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super().__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        normalized = torch.zeros_like(input.feats)
        batch_size = input.shape[0]
        for k in range(batch_size):
            rows = input.layout[k]
            num_channels = input.shape[1]
            # (n, C) -> (1, C, n): GroupNorm expects channels in dim 1.
            sample = input.feats[rows].permute(1, 0).reshape(1, num_channels, -1)
            sample = super().forward(sample)
            # (1, C, n) -> (n, C), then write back in place.
            normalized[rows] = sample.reshape(num_channels, -1).permute(1, 0)
        return input.replace(normalized)
class SparseLayerNorm(nn.LayerNorm):
    """
    LayerNorm over the channel dimension of a variable-length tensor.

    LayerNorm normalises each row over its trailing ``normalized_shape``
    dimensions, so it is applied directly to the (N, C) features. Transposing
    each batch element to (1, C, n) first would put the voxel count in the
    normalised position and fail the shape check whenever n != C.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        # Rows are independent, so no per-batch loop is required.
        return input.replace(super().forward(input.feats))
class SparseGroupNorm32(SparseGroupNorm):
    """
    GroupNorm variant that runs in float32 and casts the result back to the input dtype.
    """

    def forward(self, x: VarLenTensor) -> VarLenTensor:
        original_dtype = x.dtype
        as_float32 = manual_cast(x, torch.float32)
        normalized = super().forward(as_float32)
        return manual_cast(normalized, original_dtype)
class SparseLayerNorm32(SparseLayerNorm):
    """
    LayerNorm variant that runs in float32 and casts the result back to the input dtype.
    """

    def forward(self, x: VarLenTensor) -> VarLenTensor:
        original_dtype = x.dtype
        as_float32 = manual_cast(x, torch.float32)
        normalized = super().forward(as_float32)
        return manual_cast(normalized, original_dtype)
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__init__.py
0 → 100644
View file @
f05e915f
from
.basic
import
*
from
.spatial2channel
import
*
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/basic.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/__pycache__/spatial2channel.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/basic.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
__all__
=
[
'SparseDownsample'
,
'SparseUpsample'
,
]
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.

    Voxels that fall into the same coarse cell are pooled with `mode`:
    'mean' (average pooling) or 'max' (max pooling).
    """
    def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'):
        super(SparseDownsample, self).__init__()
        self.factor = factor
        self.mode = mode
        assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'

    def forward(self, x: SparseTensor) -> SparseTensor:
        """
        Pool the features of `x` onto the coarse grid.

        The coarse coordinates and the fine-to-coarse index are cached on `x`
        under ``downsample_<factor>``. On the first pass the output also
        receives ``upsample_<factor>``, ``shape`` and, in training mode,
        ``subdivision`` cache entries.
        """
        cache = x.get_spatial_cache(f'downsample_{self.factor}')
        if cache is None:
            DIM = x.coords.shape[-1] - 1

            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor

            # Encode (batch, coarse coords) as one integer so that unique()
            # yields the coarse voxels and the fine -> coarse index.
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)

            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx = cache

        # torch.scatter_reduce names the max reduction 'amax'; passing 'max'
        # raises at runtime.
        reduce = 'amax' if self.mode == 'max' else self.mode
        new_feats = torch.scatter_reduce(
            torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
            src=x.feats,
            reduce=reduce,
            include_self=False,
        )
        out = SparseTensor(new_feats, new_coords, x._shape)
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache

        if cache is None:
            x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
            out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
            out.register_spatial_cache(f'shape', torch.Size(MAX))
            # NOTE(review): kept inside the first-pass branch because it reads
            # DIM, which is only defined there - confirm against the original
            # indentation.
            if self.training:
                # Mark which of the factor**DIM children of each coarse voxel are occupied.
                subidx = x.coords[:, 1:] % self.factor
                subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
                subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
                subdivision[idx, subidx] = True
                out.register_spatial_cache(f'subdivision', subdivision)

        return out
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.

    Each output voxel copies the features of the input voxel it descends from
    (nearest neighbor interpolation).

    Args:
        factor: Integer multiplier applied to every spatial coordinate.
    """

    def __init__(self, factor: int):
        super().__init__()
        self.factor = factor

    def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
        """
        Expand `x` onto the finer grid.

        The fine coordinates and the output-to-input row mapping are taken from
        the spatial cache entry `upsample_{factor}` when present; otherwise they
        are derived from `subdivision`.

        Args:
            x: Sparse tensor to upsample.
            subdivision: Sparse tensor whose `feats` flag, per row of `x`, which
                sub-cells to create. Only read when the cache entry is missing.

        Returns:
            The upsampled sparse tensor, with its scale divided by `factor`.

        Raises:
            ValueError: If neither the cache entry nor `subdivision` is available.
        """
        ndim = x.coords.shape[-1] - 1
        cached = x.get_spatial_cache(f'upsample_{self.factor}')

        if cached is not None:
            fine_coords, parent_rows = cached
        elif subdivision is None:
            raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.')
        else:
            occupancy = subdivision.feats
            # Number of children to emit for every input row.
            children_per_row = occupancy.sum(dim=-1)
            # Packed sub-cell position of every child, in row order.
            child_slot = occupancy.nonzero()[:, -1]
            total_children = child_slot.shape[0]

            # Scale the spatial columns, then repeat each row once per child.
            fine_coords = x.coords.clone().detach()
            fine_coords[:, 1:] *= self.factor
            fine_coords = torch.repeat_interleave(
                fine_coords, children_per_row, dim=0, output_size=total_children
            )
            # Unpack the sub-cell position into one offset per spatial axis.
            for axis in range(ndim):
                fine_coords[:, axis + 1] += (child_slot // self.factor ** axis) % self.factor

            parent_rows = torch.repeat_interleave(
                torch.arange(x.coords.shape[0], device=x.device),
                children_per_row,
                dim=0,
                output_size=total_children,
            )

        out = SparseTensor(x.feats[parent_rows], fine_coords, x._shape)
        out._scale = tuple(s / self.factor for s in x._scale)
        # The spatial cache is shared with the output only when the cached
        # mapping was used (original note: only keep cache when subdiv
        # following it).
        if cached is not None:
            out._spatial_cache = x._spatial_cache
        return out
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/sparse/spatial/spatial2channel.py
0 → 100644
View file @
f05e915f
from
typing
import
*
import
torch
import
torch.nn
as
nn
from
..
import
SparseTensor
class SparseSpatial2Channel(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.

    Instead of pooling, the features of all voxels that fall into the same
    coarse cell are placed side by side along the channel axis, so the channel
    count grows by `factor ** DIM`. Slots of unoccupied sub-cells stay zero.

    Args:
        factor: Integer divisor applied to every spatial coordinate (default 2).
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: SparseTensor) -> SparseTensor:
        """
        Rearrange `x` from spatial resolution into channels.

        The mapping is read from the spatial cache entry
        `spatial2channel_{factor}` when present. Otherwise it is computed and
        registered, along with the inverse `channel2spatial_{factor}` mapping
        and the coarse grid `shape` on the output, plus a boolean `subdivision`
        occupancy mask when the module is in training mode.

        Args:
            x: Sparse tensor to downsample.

        Returns:
            The coarse sparse tensor, with its scale multiplied by `factor`.
        """
        ndim = x.coords.shape[-1] - 1
        cache_key = f'spatial2channel_{self.factor}'
        cached = x.get_spatial_cache(cache_key)

        if cached is not None:
            new_coords, idx, subidx = cached
        else:
            # Column 0 is kept as is; the spatial columns are integer-divided.
            columns = list(x.coords.unbind(dim=-1))
            for axis in range(ndim):
                columns[axis + 1] = columns[axis + 1] // self.factor

            # Position of each voxel inside its coarse cell, packed into one
            # integer in [0, factor ** ndim).
            remainder = x.coords[:, 1:] % self.factor
            subidx = sum(remainder[..., axis] * self.factor ** axis for axis in range(ndim))

            # Coarse grid extent (ceil division) and row-major packing strides.
            grid = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            strides = torch.cumprod(torch.tensor(grid[::-1]), 0).tolist()[::-1] + [1]
            code = sum(c * stride for c, stride in zip(columns, strides))
            # `idx` maps every input voxel to the row of its coarse voxel.
            code, idx = code.unique(return_inverse=True)

            # Unpack the unique codes back into per-column coordinates.
            decoded = [code // strides[0]]
            for axis in range(ndim):
                decoded.append((code // strides[axis + 1]) % grid[axis])
            new_coords = torch.stack(decoded, dim=-1)

        slots = self.factor ** ndim
        num_out = new_coords.shape[0]

        # One row per (coarse voxel, sub-cell) pair; occupied slots are filled.
        new_feats = torch.zeros(
            num_out * slots,
            x.feats.shape[1],
            device=x.feats.device,
            dtype=x.feats.dtype,
        )
        new_feats[idx * slots + subidx] = x.feats

        if x._shape is None:
            out_shape = None
        else:
            out_shape = torch.Size([x._shape[0], x._shape[1] * slots])
        out = SparseTensor(new_feats.reshape(num_out, -1), new_coords, out_shape)
        out._scale = tuple(s * self.factor for s in x._scale)
        out._spatial_cache = x._spatial_cache

        if cached is None:
            x.register_spatial_cache(cache_key, (new_coords, idx, subidx))
            out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
            out.register_spatial_cache(f'shape', torch.Size(grid))
            if self.training:
                # Mark which sub-cells of every coarse voxel were occupied.
                subdivision = torch.zeros((num_out, slots), device=x.device, dtype=torch.bool)
                subdivision[idx, subidx] = True
                out.register_spatial_cache(f'subdivision', subdivision)

        return out
class SparseChannel2Spatial(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.

    The channel axis is split into `factor ** DIM` slots and each requested
    output voxel takes the slot that corresponds to its position inside the
    coarse cell, so the channel count shrinks by `factor ** DIM`.

    Args:
        factor: Integer multiplier applied to every spatial coordinate (default 2).
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
        """
        Rearrange `x` from channels back into spatial resolution.

        The mapping is taken from the spatial cache entry
        `channel2spatial_{factor}` when present; otherwise it is derived from
        `subdivision`.

        Args:
            x: Sparse tensor to upsample.
            subdivision: Sparse tensor whose `feats` flag, per row of `x`, which
                sub-cells to create. Only read when the cache entry is missing.

        Returns:
            The fine sparse tensor, with its scale divided by `factor`.

        Raises:
            ValueError: If neither the cache entry nor `subdivision` is available.
        """
        ndim = x.coords.shape[-1] - 1
        cached = x.get_spatial_cache(f'channel2spatial_{self.factor}')

        if cached is not None:
            new_coords, idx, subidx = cached
        elif subdivision is None:
            raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
        else:
            occupancy = subdivision.feats                # [N, self.factor ** DIM]
            children_per_row = occupancy.sum(dim=-1)     # [N]
            # Packed sub-cell position of every child, in row order.
            subidx = occupancy.nonzero()[:, -1]
            total_children = subidx.shape[0]

            # Scale the spatial columns, then repeat each row once per child.
            new_coords = x.coords.clone().detach()
            new_coords[:, 1:] *= self.factor
            new_coords = torch.repeat_interleave(
                new_coords, children_per_row, dim=0, output_size=total_children
            )
            # Unpack the sub-cell position into one offset per spatial axis.
            for axis in range(ndim):
                new_coords[:, axis + 1] += (subidx // self.factor ** axis) % self.factor

            idx = torch.repeat_interleave(
                torch.arange(x.coords.shape[0], device=x.device),
                children_per_row,
                dim=0,
                output_size=total_children,
            )

        slots = self.factor ** ndim

        # View the features as one row per (coarse voxel, sub-cell) pair and
        # pick the row belonging to every output voxel.
        per_slot_feats = x.feats.reshape(x.feats.shape[0] * slots, -1)
        new_feats = per_slot_feats[idx * slots + subidx]

        if x._shape is None:
            out_shape = None
        else:
            out_shape = torch.Size([x._shape[0], x._shape[1] // slots])
        out = SparseTensor(new_feats, new_coords, out_shape)
        out._scale = tuple(s / self.factor for s in x._scale)
        # The spatial cache is shared with the output only when the cached
        # mapping was used (original note: only keep cache when subdiv
        # following it).
        if cached is not None:
            out._spatial_cache = x._spatial_cache
        return out
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__init__.py
0 → 100644
View file @
f05e915f
from
.blocks
import
*
from
.modulated
import
*
\ No newline at end of file
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
TRELLIS.2_DCU/trellis2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc
0 → 100644
View file @
f05e915f
File added
Prev
1
…
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment