Commit 1a333c99 authored by yan.yan's avatar yan.yan
Browse files

add support for pytorch 1.5

parent b32079a2
# Changelog # Changelog
## [2.1.8] - 2021-11-15
### Added
- Add support for pytorch 1.5
## [2.1.7] - 2021-11-11 ## [2.1.7] - 2021-11-11
### Fixed ### Fixed
- Fix a bug when net have inverse and run inference in eval mode. - Fix a bug when net have inverse and run inference in eval mode.
......
...@@ -61,7 +61,7 @@ Spconv 1.x users **NEED READ [THIS](docs/SPCONV_2_BREAKING_CHANGEs.md)** before ...@@ -61,7 +61,7 @@ Spconv 1.x users **NEED READ [THIS](docs/SPCONV_2_BREAKING_CHANGEs.md)** before
* fp32 (not tf32) training/inference speed is increased (+50~80%) * fp32 (not tf32) training/inference speed is increased (+50~80%)
* fp16 training/inference speed is greatly increased when your layer support tensor core (channel size must be multiple of 8). * fp16 training/inference speed is greatly increased when your layer support tensor core (channel size must be multiple of 8).
* int8 op is ready, but we still need some time to figure out how to run int8 in pytorch. * int8 op is ready, but we still need some time to figure out how to run int8 in pytorch.
* doesn't depend on pytorch binary. * [doesn't depend on pytorch binary](docs/FAQ.md#What-does-no-dependency-on-pytorch-mean), but you may need at least pytorch >= 1.5.0 to run spconv 2.x.
* since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference. * since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference.
## Spconv 2.x Development and Roadmap ## Spconv 2.x Development and Roadmap
......
# Frequently Asked Questions
- [What does no dependency on pytorch mean?](#What-does-no-dependency-on-pytorch-mean)
## What does no dependency on pytorch mean?
This means that spconv 2.x has no dependency on the pytorch shared libraries: when you use ```ldd``` to inspect the shared objects required by our shared library, no pytorch library is listed.
This **doesn't** mean spconv 2.x can run with **any** version of pytorch.
Most pytorch extension repos use the official pytorch extension build system. Libraries built from these extensions depend on the pytorch c++ library, so they cannot meet the requirements of [manylinux](https://github.com/pypa/manylinux). The official python package server, [PyPI](https://pypi.org/), and its mirrors only accept manylinux packages for linux platforms, so we must remove all pytorch code from our c++ code to create manylinux packages.
Spconv 2.x uses two core features of pytorch to meet the manylinux requirements: ```torch.Tensor.data_ptr``` and ```torch.cuda.current_stream().cuda_stream```. The first is used to get the data pointer of a ```torch.Tensor```; the second is used to get the cuda stream pointer. As long as these features are available in pytorch, we don't need pytorch in our c++ code anymore.
\ No newline at end of file
...@@ -206,15 +206,14 @@ setup( ...@@ -206,15 +206,14 @@ setup(
install_requires=REQUIRED, install_requires=REQUIRED,
extras_require=EXTRAS, extras_require=EXTRAS,
include_package_data=True, include_package_data=True,
license='MIT', license='Apache License 2.0',
classifiers=[ classifiers=[
# Trove classifiers # Trove classifiers
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
'License :: OSI Approved :: MIT License', 'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python', 'Programming Language :: Python',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy'
], ],
# $ setup.py publish support. # $ setup.py publish support.
cmdclass=cmdclass, cmdclass=cmdclass,
......
...@@ -15,19 +15,33 @@ ...@@ -15,19 +15,33 @@
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
from typing import Optional from typing import Optional, TypeVar
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
from spconv.pytorch import ops from spconv.pytorch import ops
import torch.cuda.amp as amp from spconv.pytorch.constants import PYTORCH_VERSION
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
import numpy as np import numpy as np
from typing import List from typing import List
_T = TypeVar("_T")


def identity_decorator(func: _T) -> _T:
    """Return ``func`` unchanged.

    No-op decorator: the decorated object is handed back as-is, so applying
    it has no effect on the wrapped callable.
    """
    return func
# Select the decorators applied to the autograd Function forward/backward
# methods, based on the installed pytorch version.
if PYTORCH_VERSION >= [1, 6, 0]:
    # torch.cuda.amp is imported only on this branch, so older pytorch
    # releases never touch it (presumably it is unavailable there).
    import torch.cuda.amp as amp
    _TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16)
    _TORCH_CUSTOM_BWD = amp.custom_bwd
else:
    # Below the version threshold, decorate with no-ops instead.
    _TORCH_CUSTOM_FWD = identity_decorator
    _TORCH_CUSTOM_BWD = identity_decorator
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, def forward(ctx,
features, features,
filters, filters,
...@@ -50,7 +64,7 @@ class SparseConvFunction(Function): ...@@ -50,7 +64,7 @@ class SparseConvFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer timer = ctx.timer
...@@ -69,7 +83,7 @@ class SparseConvFunction(Function): ...@@ -69,7 +83,7 @@ class SparseConvFunction(Function):
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, def forward(ctx,
features, features,
filters, filters,
...@@ -94,7 +108,7 @@ class SparseInverseConvFunction(Function): ...@@ -94,7 +108,7 @@ class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer timer = ctx.timer
...@@ -114,7 +128,7 @@ class SparseInverseConvFunction(Function): ...@@ -114,7 +128,7 @@ class SparseInverseConvFunction(Function):
class SparseImplicitGemmFunction(Function): class SparseImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, def forward(ctx,
features: torch.Tensor, features: torch.Tensor,
filters: torch.Tensor, filters: torch.Tensor,
...@@ -151,7 +165,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -151,7 +165,7 @@ class SparseImplicitGemmFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
features, filters, pair_fwd, pair_bwd = ctx.saved_tensors features, filters, pair_fwd, pair_bwd = ctx.saved_tensors
mask_width = ctx.mask_width mask_width = ctx.mask_width
...@@ -185,7 +199,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -185,7 +199,7 @@ class SparseImplicitGemmFunction(Function):
class SubMConvFunction(Function): class SubMConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, def forward(ctx,
features, features,
filters, filters,
...@@ -209,7 +223,7 @@ class SubMConvFunction(Function): ...@@ -209,7 +223,7 @@ class SubMConvFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer timer = ctx.timer
...@@ -229,7 +243,7 @@ class SubMConvFunction(Function): ...@@ -229,7 +243,7 @@ class SubMConvFunction(Function):
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, features, indice_pairs, indice_pair_num, def forward(ctx, features, indice_pairs, indice_pair_num,
num_activate_out): num_activate_out):
out = ops.indice_maxpool(features, indice_pairs, indice_pair_num, out = ops.indice_maxpool(features, indice_pairs, indice_pair_num,
...@@ -239,7 +253,7 @@ class SparseMaxPoolFunction(Function): ...@@ -239,7 +253,7 @@ class SparseMaxPoolFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, out = ctx.saved_tensors indice_pairs, indice_pair_num, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_backward(features, out, grad_output, input_bp = ops.indice_maxpool_backward(features, out, grad_output,
...@@ -249,7 +263,7 @@ class SparseMaxPoolFunction(Function): ...@@ -249,7 +263,7 @@ class SparseMaxPoolFunction(Function):
class SparseMaxPoolImplicitGemmFunction(Function): class SparseMaxPoolImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @_TORCH_CUSTOM_FWD
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
indice_pairs_bwd: torch.Tensor, num_activate_out: int): indice_pairs_bwd: torch.Tensor, num_activate_out: int):
out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd, out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
...@@ -259,7 +273,7 @@ class SparseMaxPoolImplicitGemmFunction(Function): ...@@ -259,7 +273,7 @@ class SparseMaxPoolImplicitGemmFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @_TORCH_CUSTOM_BWD
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs_bwd, features, out = ctx.saved_tensors indice_pairs_bwd, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_implicit_gemm_backward( input_bp = ops.indice_maxpool_implicit_gemm_backward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment