Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
e2df774f
Commit
e2df774f
authored
Nov 06, 2022
by
yan.yan
Browse files
fix #532 overflow in huge dim
parent
1f5ce924
Changes
13
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
398 additions
and
246 deletions
+398
-246
.github/workflows/build.yaml
.github/workflows/build.yaml
+1
-1
CHANGELOG.md
CHANGELOG.md
+4
-0
README.md
README.md
+5
-3
pyproject.toml
pyproject.toml
+1
-1
setup.py
setup.py
+2
-2
spconv/algo.py
spconv/algo.py
+0
-1
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+2
-2
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+308
-223
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+2
-2
spconv/test_utils.py
spconv/test_utils.py
+4
-3
test/dev.py
test/dev.py
+66
-5
test/test_all_algo.py
test/test_all_algo.py
+2
-2
version.txt
version.txt
+1
-1
No files found.
.github/workflows/build.yaml
View file @
e2df774f
...
@@ -116,7 +116,7 @@ jobs:
...
@@ -116,7 +116,7 @@ jobs:
strategy
:
strategy
:
matrix
:
matrix
:
python-version
:
[
'
3.7'
,
'
3.8'
,
'
3.9'
,
'
3.10'
,
'
3.11'
]
# this version is only used for upload.
python-version
:
[
'
3.7'
,
'
3.8'
,
'
3.9'
,
'
3.10'
,
'
3.11'
]
# this version is only used for upload.
cuda-version
:
[
'
102'
,
'
113'
,
'
114'
,
'
116'
,
'
117'
,
'
118'
]
cuda-version
:
[
'
102'
,
'
113'
,
'
114'
,
'
116'
,
'
117'
,
'
118'
,
'
'
]
steps
:
steps
:
-
uses
:
actions/checkout@master
-
uses
:
actions/checkout@master
...
...
CHANGELOG.md
View file @
e2df774f
# Changelog
# Changelog
## [2.2.5] - 2022-11-05
### Fixed
-
Fix overflow when shape is too large
## [2.2.4] - 2022-10-13
## [2.2.4] - 2022-10-13
### Added
### Added
-
Add prebuilt for CUDA 11.8 (RTX 4090 and H100) and CUDA 11.6.
-
Add prebuilt for CUDA 11.8 (RTX 4090 and H100) and CUDA 11.6.
...
...
README.md
View file @
e2df774f
...
@@ -41,8 +41,8 @@
...
@@ -41,8 +41,8 @@
[
pypi-url-118
]:
https://pypi.org/project/spconv-cu118/
[
pypi-url-118
]:
https://pypi.org/project/spconv-cu118/
[
pypi-download-118
]:
https://img.shields.io/pypi/dm/spconv-cu118
[
pypi-download-118
]:
https://img.shields.io/pypi/dm/spconv-cu118
[
pypi-url-116
]:
https://pypi.org/project/spconv-cu11
8
/
[
pypi-url-116
]:
https://pypi.org/project/spconv-cu11
6
/
[
pypi-download-116
]:
https://img.shields.io/pypi/dm/spconv-cu11
8
[
pypi-download-116
]:
https://img.shields.io/pypi/dm/spconv-cu11
6
# SpConv: Spatially Sparse Convolution Library
# SpConv: Spatially Sparse Convolution Library
[

](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)
[

](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)
...
@@ -57,7 +57,9 @@
...
@@ -57,7 +57,9 @@
| CUDA 11.4 |
[
![PyPI Version
][
pypi-ver-114
]
]
[
pypi-url-114] | ```pip install spconv-cu114```| [![pypi monthly download
][
pypi-download-114
]
][pypi-url-114]|
| CUDA 11.4 |
[
![PyPI Version
][
pypi-ver-114
]
]
[
pypi-url-114] | ```pip install spconv-cu114```| [![pypi monthly download
][
pypi-download-114
]
][pypi-url-114]|
| CUDA 11.6 |
[
![PyPI Version
][
pypi-ver-116
]
]
[
pypi-url-116] | ```pip install spconv-cu116```| [![pypi monthly download
][
pypi-download-116
]
][pypi-url-116]|
| CUDA 11.6 |
[
![PyPI Version
][
pypi-ver-116
]
]
[
pypi-url-116] | ```pip install spconv-cu116```| [![pypi monthly download
][
pypi-download-116
]
][pypi-url-116]|
| CUDA 11.7 |
[
![PyPI Version
][
pypi-ver-117
]
]
[
pypi-url-117] | ```pip install spconv-cu117```| [![pypi monthly download
][
pypi-download-117
]
][pypi-url-117]|
| CUDA 11.7 |
[
![PyPI Version
][
pypi-ver-117
]
]
[
pypi-url-117] | ```pip install spconv-cu117```| [![pypi monthly download
][
pypi-download-117
]
][pypi-url-117]|
| CUDA 11.8 |
[
![PyPI Version
][
pypi-ver-118
]
]
[
pypi-url-118] | ```pip install spconv-cu118```| [![pypi monthly download
][
pypi-download-118
]
][pypi-url-118]|
| CUDA 11.8
*
|
[
![PyPI Version
][
pypi-ver-118
]
]
[
pypi-url-118] | ```pip install spconv-cu118```| [![pypi monthly download
][
pypi-download-118
]
][pypi-url-118]|
*
: sm_89 and sm_90 is added in CUDA 11.8. If you use RTX 4090 or H100, you should use this version.
<!-- | CUDA 12.0 | [![PyPI Version][pypi-ver-120]][pypi-url-120] | ```pip install spconv-cu120```| [![pypi monthly download][pypi-download-120]][pypi-url-120]| -->
<!-- | CUDA 12.0 | [![PyPI Version][pypi-ver-120]][pypi-url-120] | ```pip install spconv-cu120```| [![pypi monthly download][pypi-download-120]][pypi-url-120]| -->
...
...
pyproject.toml
View file @
e2df774f
[build-system]
[build-system]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.
5
"
]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.
7
"
]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu118-0.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend
=
"setuptools.build_meta"
build-backend
=
"setuptools.build_meta"
setup.py
View file @
e2df774f
...
@@ -39,9 +39,9 @@ if cuda_ver:
...
@@ -39,9 +39,9 @@ if cuda_ver:
cuda_ver_str
=
cuda_ver
.
replace
(
"."
,
""
)
# 10.2 to 102
cuda_ver_str
=
cuda_ver
.
replace
(
"."
,
""
)
# 10.2 to 102
RELEASE_NAME
+=
"-cu{}"
.
format
(
cuda_ver_str
)
RELEASE_NAME
+=
"-cu{}"
.
format
(
cuda_ver_str
)
deps
=
[
"cumm-cu{}>=0.3.
4
"
.
format
(
cuda_ver_str
)]
deps
=
[
"cumm-cu{}>=0.3.
7
"
.
format
(
cuda_ver_str
)]
else
:
else
:
deps
=
[
"cumm>=0.3.
4
"
]
deps
=
[
"cumm>=0.3.
7
"
]
...
...
spconv/algo.py
View file @
e2df774f
...
@@ -618,7 +618,6 @@ class SimpleConv:
...
@@ -618,7 +618,6 @@ class SimpleConv:
]
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
.
clear
()
self
.
lock
=
Lock
()
self
.
lock
=
Lock
()
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_desps
)
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_desps
)
...
...
spconv/csrc/sparse/all.py
View file @
e2df774f
...
@@ -1677,7 +1677,7 @@ class SpconvOps(pccm.Class):
...
@@ -1677,7 +1677,7 @@ class SpconvOps(pccm.Class):
}}
}}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>())
* batch_size
;
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
...
@@ -2022,7 +2022,7 @@ Your Conv Params: )" << "\\n";
...
@@ -2022,7 +2022,7 @@ Your Conv Params: )" << "\\n";
}}
}}
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
std::vector<int64_t> output_dims_i64(out_shape.begin(), out_shape.end());
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
int64_t out_spatial_volume = std::accumulate(output_dims_i64.begin(),
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
output_dims_i64.end(), int64_t(1), std::multiplies<int64_t>())
* batch_size
;
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
bool use_int64_hash_k = out_spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
...
...
spconv/csrc/sparse/indices.py
View file @
e2df774f
This diff is collapsed.
Click to expand it.
spconv/pytorch/ops.py
View file @
e2df774f
...
@@ -185,7 +185,7 @@ def get_indice_pairs(indices: torch.Tensor,
...
@@ -185,7 +185,7 @@ def get_indice_pairs(indices: torch.Tensor,
)
)
assert
algo
==
ConvAlgo
.
Native
,
"TODO"
assert
algo
==
ConvAlgo
.
Native
,
"TODO"
# indices = indices.cpu()
# indices = indices.cpu()
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
out_shape
,
1
)
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
out_shape
,
1
)
*
batch_size
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
...
@@ -457,7 +457,7 @@ def get_indice_pairs_implicit_gemm(
...
@@ -457,7 +457,7 @@ def get_indice_pairs_implicit_gemm(
raise
ValueError
(
raise
ValueError
(
f
"your out spatial shape
{
out_shape
}
reach zero!!! input shape:
{
spatial_shape
}
"
f
"your out spatial shape
{
out_shape
}
reach zero!!! input shape:
{
spatial_shape
}
"
)
)
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
spatial_shape
,
1
)
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
spatial_shape
,
1
)
*
batch_size
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
assert
algo
==
ConvAlgo
.
MaskImplicitGemm
or
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
,
"TODO"
assert
algo
==
ConvAlgo
.
MaskImplicitGemm
or
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
,
"TODO"
...
...
spconv/test_utils.py
View file @
e2df774f
...
@@ -145,7 +145,8 @@ def generate_sparse_data(shape,
...
@@ -145,7 +145,8 @@ def generate_sparse_data(shape,
integer
=
False
,
integer
=
False
,
data_range
=
(
-
1
,
1
),
data_range
=
(
-
1
,
1
),
with_dense
=
True
,
with_dense
=
True
,
dtype
=
np
.
float32
):
dtype
=
np
.
float32
,
shape_scale
=
1
):
dense_shape
=
shape
dense_shape
=
shape
ndim
=
len
(
dense_shape
)
ndim
=
len
(
dense_shape
)
# num_points = np.random.randint(10, 100, size=[batch_size, ndim])
# num_points = np.random.randint(10, 100, size=[batch_size, ndim])
...
@@ -153,9 +154,9 @@ def generate_sparse_data(shape,
...
@@ -153,9 +154,9 @@ def generate_sparse_data(shape,
# num_points = np.array([3, 2])
# num_points = np.array([3, 2])
batch_size
=
len
(
num_points
)
batch_size
=
len
(
num_points
)
batch_indices
=
[]
batch_indices
=
[]
coors_total
=
np
.
stack
(
np
.
meshgrid
(
*
[
np
.
arange
(
0
,
s
)
for
s
in
shape
]),
coors_total
=
np
.
stack
(
np
.
meshgrid
(
*
[
np
.
arange
(
0
,
s
//
shape_scale
)
for
s
in
shape
]),
axis
=-
1
)
axis
=-
1
)
coors_total
=
coors_total
.
reshape
(
-
1
,
ndim
)
coors_total
=
coors_total
.
reshape
(
-
1
,
ndim
)
*
shape_scale
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
np
.
random
.
shuffle
(
coors_total
)
np
.
random
.
shuffle
(
coors_total
)
inds_total
=
coors_total
[:
num_points
[
i
]]
inds_total
=
coors_total
[:
num_points
[
i
]]
...
...
test/dev.py
View file @
e2df774f
import
spconv
import
spconv.pytorch
as
spconv
from
spconv.core
import
ConvAlgo
import
spconv.pytorch
as
spconv
from
spconv.test_utils
import
TestCase
,
generate_sparse_data
,
params_grid
from
spconv.core_cc.cumm.common
import
CompileInfo
import
torch
if
__name__
==
"__main__"
:
import
numpy
as
np
print
(
CompileInfo
.
arch_is_compatible_gemm
((
9
,
0
)),
CompileInfo
.
arch_is_compiled_gemm
((
9
,
0
)))
class
SparseMaxPool2dTestTorch
(
torch
.
nn
.
Module
):
print
(
CompileInfo
.
arch_is_compatible_gemm
((
8
,
6
)),
CompileInfo
.
arch_is_compiled_gemm
((
8
,
6
)))
def
__init__
(
self
,
num_layers
,
ndim
,
shape
,
kernel_size
,
stride
,
padding
,
\ No newline at end of file
dilation
,
algo
):
super
().
__init__
()
self
.
algo
=
algo
layers
=
[
spconv
.
SparseMaxPool2d
(
kernel_size
,
stride
,
padding
,
dilation
,
algo
=
algo
)
]
for
i
in
range
(
1
,
num_layers
):
layers
.
append
(
spconv
.
SparseMaxPool2d
(
kernel_size
,
stride
,
padding
,
dilation
,
algo
=
algo
))
self
.
net
=
spconv
.
SparseSequential
(
*
layers
,
)
self
.
shape
=
shape
def
forward
(
self
,
features
,
coors
,
batch_size
):
coors
=
coors
.
int
()
x
=
spconv
.
SparseConvTensor
(
features
,
coors
,
self
.
shape
,
batch_size
)
return
self
.
net
(
x
)
# .dense()
shapes
=
[[
65536
,
65536
]]
batchsizes
=
[
32
]
in_channels
=
[
32
]
out_channels
=
[
32
]
ksizes
=
[
2
]
strides
=
[
2
]
paddings
=
[
0
]
dilations
=
[
1
]
algos
=
[
# ConvAlgo.Native,
ConvAlgo
.
MaskImplicitGemm
,
# ConvAlgo.MaskSplitImplicitGemm
]
devices
=
[
"cuda:0"
]
for
dev
,
shape
,
bs
,
IC
,
OC
,
k
,
s
,
p
,
d
,
al
in
params_grid
(
devices
,
shapes
,
batchsizes
,
in_channels
,
out_channels
,
ksizes
,
strides
,
paddings
,
dilations
,
algos
):
device
=
torch
.
device
(
dev
)
num_points
=
[
1000
]
*
bs
print
(
1
)
sparse_dict
=
generate_sparse_data
(
shape
,
num_points
,
IC
,
with_dense
=
False
,
data_range
=
[
0.1
,
1
],
shape_scale
=
64
)
print
(
2
)
net
=
SparseMaxPool2dTestTorch
(
1
,
2
,
shape
,
k
,
s
,
p
,
d
,
al
).
to
(
device
)
features
=
np
.
ascontiguousarray
(
sparse_dict
[
"features"
]).
astype
(
np
.
float32
)
indices
=
np
.
ascontiguousarray
(
sparse_dict
[
"indices"
][:,
[
2
,
0
,
1
]]).
astype
(
np
.
int32
)
print
(
indices
.
max
(
0
))
indices_t
=
torch
.
from_numpy
(
indices
).
int
().
to
(
device
)
features_t
=
torch
.
from_numpy
(
features
).
to
(
device
)
features_t
.
requires_grad
=
True
out
=
net
(
features_t
,
indices_t
,
bs
)
print
(
out
.
indices
.
min
(
0
))
test/test_all_algo.py
View file @
e2df774f
...
@@ -916,8 +916,8 @@ def _test_native_conv_cuda(subm: bool):
...
@@ -916,8 +916,8 @@ def _test_native_conv_cuda(subm: bool):
def
test_all_algo_unit
():
def
test_all_algo_unit
():
# for i in range(5):
# for i in range(5):
#
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda
(
True
)
#
_test_impgemm_conv_cuda(False)
_test_impgemm_conv_cuda
(
False
)
_test_native_conv_cuda
(
True
)
_test_native_conv_cuda
(
True
)
_test_native_conv_cuda
(
False
)
_test_native_conv_cuda
(
False
)
...
...
version.txt
View file @
e2df774f
2.2.
4
2.2.
5
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment