Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
860527e2
Commit
860527e2
authored
Nov 11, 2021
by
yan.yan
Browse files
v2.1.7: fix a bug when run inference in eval mode
parent
9bf390da
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
96 additions
and
27 deletions
+96
-27
CHANGELOG.md
CHANGELOG.md
+4
-0
README.md
README.md
+3
-23
docs/SPCONV_DEVELOP_PLAN.md
docs/SPCONV_DEVELOP_PLAN.md
+83
-0
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+4
-2
spconv/pytorch/pool.py
spconv/pytorch/pool.py
+1
-1
version.txt
version.txt
+1
-1
No files found.
CHANGELOG.md
View file @
860527e2
# Changelog
## [2.1.7] - 2021-11-11
### Fixed
-
Fix a bug when net have inverse and run inference in eval mode.
## [2.1.6] - 2021-11-10
### Fixed
-
Fix missing -fopenmp in linker for CPU only
...
...
README.md
View file @
860527e2
...
...
@@ -62,19 +62,12 @@ Spconv 1.x users **NEED READ [THIS](docs/SPCONV_2_BREAKING_CHANGEs.md)** before
* doesn't depend on pytorch binary.
* since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference.
Spconv 2.
1 vs 1.x speed:
##
Spconv 2.
x Development and Roadmap
| | 1080Ti Spconv 1.x F32 | 1080Ti Spconv 2.0 F32 | 3080M* Spconv 2.1 F16 |
| -------------- |:---------------------:| ---------------------:| ----------:|
| 27x128x128 Fwd | 11ms | 5.4ms | 1.4ms |
See [dev plan](docs/SPCONV_DEVELOP_PLAN.md). A complete guide of spconv development will be released soon.
\* 3080M (Laptop) ~= 3070 Desktop
<!--
TODO Spconv vs [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) vs [torchsparse](https://github.com/mit-han-lab/torchsparse)
-->
## Usage
Firstly you need to use ```import spconv.pytorch as spconv``` in spconv 2.x.
...
...
@@ -160,20 +153,7 @@ You need to rebuild ```cumm``` first if you are build along a CUDA version that
5. run ```
pip install pccm cumm wheel
```
6. run ```
python setup.py bdist_wheel
```+```
pip install dists/xxx.whl
```
## Roadmap for Spconv 2.2-2.3:
*
TensorFormat32 support for faster fp32 training when you use NVIDIA Geforce RTX 30x0/Tesla A100/Quadro RTX Ax000 (2.2)
*
change implicit gemm weight layout from KRSC to RSKC to make sure we can use native algorithm with implicit gemm weight. (2.2)
*
documents (2.2)
*
Ampere feature support (2.3)
*
pytorch int8 inference, and QAT support (2.3)
## TODO in Spconv 2.x
-
[ ] Ampere (A100 / RTX 3000 series) feature support (work in progress)
-
[ ] torch QAT support (work in progress)
-
[ ] TensorRT (torch.fx based)
-
[ ] Build C++ only package
-
[ ] JIT compilation for CUDA kernels
-
[ ] Document (low priority)
## Note
...
...
docs/SPCONV_DEVELOP_PLAN.md
0 → 100644
View file @
860527e2
<!--
Copyright 2021 Yan Yan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## Spconv 2.x Develop Plan
If someone want to contribute to spconv 2.x, feel free to start new discussion in github, or just email to me.
### v2.2 Core Features
-
[ ] TF32 support
-
[ ] Make
```ConvAlgo.Native```
runable in KRSC layout and only use this layout in future
-
[ ] PyTorch Int8 Support
### v2.3 Core Features
-
[ ] Move most of function in spconv.pytorch.ops to C++
-
[ ] Ampere multi-stage gemm support
-
[ ] Optimize CUDA Kernels for small-channel-size layers.
### v2.4 Core Features
-
[ ] nvrtc support for gemm/conv kernels
-
[ ] C++ only spconv
-
[ ] TensorRT support
### Misc Features need contribution
-
[
] Test spconv 2.x in [torch-points3d
](
https://github.com/nicolas-chaulet/torch-points3d
)
and other frameworks
-
[ ] Documents in github Page
-
[ ] Better tests
### Details
1.
TF32 support
we only need to add tf32 tensor cores to cumm. not hard.
2.
Make
```ConvAlgo.Native```
runable in KRSC layout
Add stride arg to gemm kernels, use offset + stride to force gemm kernel use KRSC layout as a "KC" matrix.
3.
PyTorch Int8 Support
...
4.
Move most of function in spconv.pytorch.ops to C++
Pure engieering work.
5.
Ampere multi-stage gemm support
Not easy, we need to use new pattern to write gemm kernels.
6.
Optimize CUDA Kernels for small-channel-size layers
modify cumm and make it support small kernels. not hard, but need time.
7.
nvrtc support for gemm/conv kernels
need to rewrite kernel params in cumm. not easy.
8.
C++ only spconv
actually code generation is easy, we can finish this easily after move ops to c++.
9.
TensorRT support
The TensorRT support is the last feature in this plan. it needs lots of engieering work and prerequisites, may cost much time.
\ No newline at end of file
spconv/pytorch/conv.py
View file @
860527e2
...
...
@@ -346,7 +346,7 @@ class SparseConvolution(SparseModule):
mask_argsort_bwd_splits
=
datas
.
mask_argsort_fwd_splits
masks
=
datas
.
masks
out_spatial_shape
=
datas
.
spatial_shape
assert
pair_fwd
.
shape
[
0
]
==
np
.
prod
(
assert
datas
.
pair_fwd
.
shape
[
0
]
==
np
.
prod
(
self
.
kernel_size
),
"inverse conv must have same kernel size as its couple conv"
...
...
@@ -362,6 +362,8 @@ class SparseConvolution(SparseModule):
masks
=
datas
.
masks
else
:
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
# we need to gen bwd indices for regular conv
# because it may be inversed.
res
=
ops
.
get_indice_pairs_implicit_gemm
(
indices
,
batch_size
,
...
...
@@ -374,7 +376,7 @@ class SparseConvolution(SparseModule):
out_padding
=
self
.
output_padding
,
subm
=
self
.
subm
,
transpose
=
self
.
transposed
,
is_train
=
self
.
training
,
is_train
=
(
not
self
.
subm
)
or
self
.
training
,
alloc
=
input
.
thrust_allocator
,
timer
=
input
.
_timer
)
outids
=
res
[
0
]
...
...
spconv/pytorch/pool.py
View file @
860527e2
...
...
@@ -178,7 +178,7 @@ class SparseMaxPool(SparseModule):
dilation
=
self
.
dilation
,
out_padding
=
out_padding
,
subm
=
self
.
subm
,
is_train
=
self
.
training
,
is_train
=
(
not
self
.
subm
)
or
self
.
training
,
alloc
=
input
.
thrust_allocator
,
timer
=
input
.
_timer
)
outids
=
res
[
0
]
...
...
version.txt
View file @
860527e2
2.1.
6
2.1.
7
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment