Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
21bb00ae
Commit
21bb00ae
authored
Jul 27, 2022
by
Yan Yan
Browse files
still working on c++ only
parent
899008fa
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
482 additions
and
170 deletions
+482
-170
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+207
-94
spconv/pytorch/pool.py
spconv/pytorch/pool.py
+254
-32
test/CMakeLists.txt
test/CMakeLists.txt
+0
-27
test/benchmark.py
test/benchmark.py
+20
-16
test/test_conv.py
test/test_conv.py
+1
-1
No files found.
spconv/pytorch/ops.py
View file @
21bb00ae
...
...
@@ -75,26 +75,31 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
output_size
.
append
(
size
)
return
output_size
class
_HashData
:
def
__init__
(
self
,
num
:
int
,
use_i64
:
bool
,
device
:
torch
.
device
)
->
None
:
if
use_i64
:
self
.
hashdata_k
=
torch
.
empty
((
num
*
2
,
),
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hashdata_v
=
torch
.
empty
((
num
*
2
,
),
self
.
hashdata_v
=
torch
.
empty
((
num
*
2
,
),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
hashdata_k_tv
=
torch_tensor_to_tv
(
self
.
hashdata_k
)
self
.
hashdata_v_tv
=
torch_tensor_to_tv
(
self
.
hashdata_v
)
else
:
self
.
hashdata
=
torch
.
empty
((
2
,
num
*
2
,
),
self
.
hashdata
=
torch
.
empty
((
2
,
num
*
2
,
),
dtype
=
torch
.
int32
,
device
=
device
)
hashdata_tv
=
torch_tensor_to_tv
(
self
.
hashdata
)
self
.
hashdata_k_tv
=
hashdata_tv
[
0
]
self
.
hashdata_v_tv
=
hashdata_tv
[
1
]
def
get_indice_pairs
(
indices
:
torch
.
Tensor
,
batch_size
:
int
,
spatial_shape
:
List
[
int
],
...
...
@@ -119,13 +124,18 @@ def get_indice_pairs(indices: torch.Tensor,
if
indices
.
is_cuda
:
stream
=
get_current_stream
()
num_act_out
=
SpconvOps
.
get_indice_pairs
(
alloc
,
torch_tensor_to_tv
(
indices
),
batch_size
,
spatial_shape
,
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
transpose
,
stream
)
num_act_out
=
SpconvOps
.
get_indice_pairs
(
alloc
,
torch_tensor_to_tv
(
indices
),
batch_size
,
spatial_shape
,
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
transpose
,
stream
)
if
subm
:
out_inds
=
indices
else
:
out_inds
=
alloc
.
allocated
[
AllocKeys
.
OutIndices
]
pair
=
alloc
.
allocated
[
AllocKeys
.
Pair
]
pair
=
alloc
.
allocated
[
AllocKeys
.
Pair
Fwd
]
indice_num_per_loc
=
alloc
.
allocated
[
AllocKeys
.
IndiceNumPerLoc
]
# print(subm, out_inds.shape, pair.shape, indice_num_per_loc.shape, num_act_out)
return
out_inds
[:
num_act_out
],
pair
,
indice_num_per_loc
...
...
@@ -146,7 +156,7 @@ def get_indice_pairs(indices: torch.Tensor,
)
assert
algo
==
ConvAlgo
.
Native
,
"TODO"
# indices = indices.cpu()
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
spatial
_shape
,
1
)
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
out
_shape
,
1
)
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
...
...
@@ -164,7 +174,8 @@ def get_indice_pairs(indices: torch.Tensor,
out_inds
=
indices
if
indices
.
is_cuda
:
stream
=
get_current_stream
()
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
...
...
@@ -234,7 +245,8 @@ def get_indice_pairs(indices: torch.Tensor,
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
out_inds_tv
=
torch_tensor_to_tv
(
out_inds
)
# hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
...
...
@@ -281,6 +293,7 @@ def get_indice_pairs(indices: torch.Tensor,
# print("REGU", time.time() - t)
return
out_inds
,
pair
,
indice_num_per_loc
def
get_indice_pairs_implicit_gemm
(
indices
:
torch
.
Tensor
,
batch_size
:
int
,
...
...
@@ -303,11 +316,11 @@ def get_indice_pairs_implicit_gemm(
out_inds,
num_inds_per_loc,
pair_fwd,
pair_bwd, #
None
if subm or inference mode
pair_bwd, #
torch.Tensor()
if subm or inference mode
pair_mask_fwd_splits,
pair_mask_bwd_splits, #
None
if subm or inference mode
pair_mask_bwd_splits, #
torch.Tensor()
if subm or inference mode
mask_argsort_fwd_splits,
mask_argsort_bwd_splits, #
None
if subm or inference mode
mask_argsort_bwd_splits, #
torch.Tensor()
if subm or inference mode
masks,
)
"""
...
...
@@ -316,39 +329,47 @@ def get_indice_pairs_implicit_gemm(
thalloc
=
TorchAllocator
(
indices
.
device
)
mask_tensor
,
num_act_out
=
SpconvOps
.
get_indice_pairs_implicit_gemm
(
thalloc
,
torch_tensor_to_tv
(
indices
),
batch_size
,
spatial_shape
,
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
transpose
,
is_train
,
stream
,
num_out_act_bound
)
algo
.
value
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
subm
,
transpose
,
is_train
,
stream
,
num_out_act_bound
)
mask_split_count
=
mask_tensor
.
dim
(
0
)
masks
=
[
mask_tensor
[
i
:
i
+
1
].
numpy
()
for
i
in
range
(
mask_split_count
)]
masks
=
[
mask_tensor
[
i
:
i
+
1
].
numpy
()
for
i
in
range
(
mask_split_count
)]
if
subm
:
out_inds
=
indices
else
:
out_inds
=
thalloc
.
allocated
[
AllocKeys
.
OutIndices
]
pair
=
thalloc
.
allocated
[
AllocKeys
.
Pair
]
indice_num_per_loc
=
thalloc
.
allocated
[
AllocKeys
.
IndiceNumPerLoc
]
if
subm
:
# for subm, if training, pair shape is [2, kv, ...]
# if not training, pair is [1, kv, ...]
pair
=
thalloc
.
allocated
[
AllocKeys
.
PairFwd
]
pair_mask
=
thalloc
.
allocated
[
AllocKeys
.
PairMask
]
mask_argsort
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
pair_mask_in_splits
=
[
pair_mask
[
i
]
for
i
in
range
(
mask_split_count
)]
pair_mask_in_splits
=
[
pair_mask
[
i
]
for
i
in
range
(
mask_split_count
)
]
mask_argsort_in_splits
=
[
mask_argsort
[
i
]
for
i
in
range
(
mask_split_count
)
]
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
pair_bwd
=
torch
.
Tensor
()
pair_fwd
=
pair
[
0
]
if
is_train
:
assert
pair
.
shape
[
0
]
==
2
pair_bwd
=
pair
[
1
]
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair_bwd
,
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
pair_bwd
=
pair
pair_bwd
=
thalloc
.
allocated
.
get
(
AllocKeys
.
PairBwd
,
torch
.
Tensor
())
pair_fwd
=
thalloc
.
allocated
[
AllocKeys
.
PairFwd
]
pair_mask_fwd
=
thalloc
.
allocated
[
AllocKeys
.
PairMask
]
pair_mask_bwd
=
torch
.
Tensor
()
mask_argsort_bwd
=
torch
.
Tensor
()
if
is_train
:
pair_mask_bwd
=
thalloc
.
allocated
[
AllocKeys
.
PairMaskBwd
]
mask_argsort_bwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSortBwd
]
mask_argsort_fwd
=
thalloc
.
allocated
[
AllocKeys
.
MaskArgSort
]
if
not
is_train
:
pair_bwd
=
torch
.
Tensor
()
pair_mask_bwd_splits
:
List
[
torch
.
Tensor
]
=
[]
mask_argsort_bwd_splits
:
List
[
torch
.
Tensor
]
=
[]
else
:
...
...
@@ -377,9 +398,6 @@ def get_indice_pairs_implicit_gemm(
kv
:
int
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ksize
,
1
)
# TODO in future we will support up to 128 kernel volume.
assert
kv
<=
32
,
"currently only support kernel volume <= 32 to use implicit gemm"
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
spatial_shape
,
1
)
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
if
not
subm
:
if
transpose
:
...
...
@@ -394,6 +412,9 @@ def get_indice_pairs_implicit_gemm(
raise
ValueError
(
f
"your out spatial shape
{
out_shape
}
reach zero!!! input shape:
{
spatial_shape
}
"
)
spatial_volume
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
spatial_shape
,
1
)
use_int64_hash_k
=
spatial_volume
>=
INT32_MAX
or
DEBUG_INT64_HASH_K
indice_dtype
=
torch
.
int64
if
use_int64_hash_k
else
indices
.
dtype
assert
algo
==
ConvAlgo
.
MaskImplicitGemm
or
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
,
"TODO"
is_mask_split
=
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
mask_split_count
=
2
if
is_mask_split
else
1
...
...
@@ -433,7 +454,8 @@ def get_indice_pairs_implicit_gemm(
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
# device=indices.device)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
pair_mask
=
torch
.
empty
((
mask_split_count
,
indices
.
shape
[
0
]),
dtype
=
torch
.
int32
,
...
...
@@ -552,7 +574,8 @@ def get_indice_pairs_implicit_gemm(
device
=
indices
.
device
)
pair_mask_bwd_tv
=
torch_tensor_to_tv
(
pair_mask_bwd
,
dtype
=
tv
.
uint32
)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
# hashdata = torch.empty((out_inds.shape[0] * 2, ),
# dtype=torch.int64,
...
...
@@ -714,13 +737,14 @@ def indice_conv(features: torch.Tensor,
if
SPCONV_CPP_GEMM
and
GEMM_CPP
is
not
None
:
# print("CPPPPPP!!!", features.device)
alloc
=
TorchAllocator
(
features
.
device
)
from
spconv.core_cc.csrc.sparse.convops
import
SimpleExternalSpconvMatmul
# ext_mm = TorchSpconvMatmul(alloc)
if
features
.
is_cuda
:
ext_mm
=
SimpleExternalSpconvMatmul
(
alloc
)
else
:
ext_mm
=
TorchSpconvMatmul
(
alloc
)
# from spconv.core_cc.csrc.sparse.convops import SimpleExternalSpconvMatmul
# if features.is_cuda:
# ext_mm = SimpleExternalSpconvMatmul(alloc)
# else:
# ext_mm = TorchSpconvMatmul(alloc)
alloc
.
allocated
[
AllocKeys
.
Features
]
=
features
alloc
.
allocated
[
AllocKeys
.
Filters
]
=
filters
...
...
@@ -731,13 +755,14 @@ def indice_conv(features: torch.Tensor,
stream
=
0
if
features
.
is_cuda
:
stream
=
get_current_stream
()
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
return
out_features
if
not
ALL_WEIGHT_IS_KRSC
:
kv_dim
=
0
is_KC_not_CK
=
not
FILTER_HWIO
...
...
@@ -779,7 +804,9 @@ def indice_conv(features: torch.Tensor,
features_np
=
torch_tensor_to_tv
(
features
).
numpy_view
()
filters_np
=
torch_tensor_to_tv
(
filters
).
numpy_view
()
out_features_np
=
torch_tensor_to_tv
(
out_features
).
numpy_view
()
np
.
matmul
(
features_np
,
filters_np
[:,
kv_center
].
T
,
out
=
out_features_np
)
np
.
matmul
(
features_np
,
filters_np
[:,
kv_center
].
T
,
out
=
out_features_np
)
# out_features = torch.mm(features, filters[:, kv_center].T)
else
:
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
...
...
@@ -826,10 +853,13 @@ def indice_conv(features: torch.Tensor,
if
features
.
dtype
==
torch
.
float16
:
inp_buffer_np
=
torch_tensor_to_tv
(
inp_buffer
).
numpy_view
()
filters_np
=
torch_tensor_to_tv
(
filters
).
numpy_view
()
filters_i_np
=
filters_np
[
i
]
if
not
ALL_WEIGHT_IS_KRSC
else
filters_np
[:,
i
]
filters_i_np
=
filters_np
[
i
]
if
not
ALL_WEIGHT_IS_KRSC
else
filters_np
[:,
i
]
filters_cur_np
=
filters_i_np
if
not
is_KC_not_CK
else
filters_i_np
.
T
out_buffer_np
=
torch_tensor_to_tv
(
out_buffer
).
numpy_view
()
np
.
matmul
(
inp_buffer_np
[:
nhot
],
filters_cur_np
,
out
=
out_buffer_np
[:
nhot
])
np
.
matmul
(
inp_buffer_np
[:
nhot
],
filters_cur_np
,
out
=
out_buffer_np
[:
nhot
])
else
:
torch
.
mm
(
inp_buffer
[:
nhot
],
filters_cur
,
out
=
out_buffer
[:
nhot
])
SpconvOps
.
scatter_add_cpu
(
c
,
out_buffer_tv
,
out_indices
)
...
...
@@ -968,8 +998,10 @@ def indice_conv_backward(features: torch.Tensor,
stream
=
0
if
features
.
is_cuda
:
stream
=
get_current_stream
()
ConvGemmOps
.
indice_conv_backward
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
out_bp_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
ConvGemmOps
.
indice_conv_backward
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
out_bp_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
inverse
,
subm
,
algo
.
value
,
stream
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
...
...
@@ -1076,10 +1108,14 @@ def indice_conv_backward(features: torch.Tensor,
filters_KC
=
filters_i
if
is_KC_not_CK
else
filters_i
.
T
if
is_KC_not_CK
:
# KN @ NC
torch
.
mm
(
out_buffer
[:
nhot
].
T
,
inp_buffer
[:
nhot
],
out
=
dfilters_i
)
torch
.
mm
(
out_buffer
[:
nhot
].
T
,
inp_buffer
[:
nhot
],
out
=
dfilters_i
)
else
:
# CN @ NK
torch
.
mm
(
inp_buffer
[:
nhot
].
T
,
out_buffer
[:
nhot
],
out
=
dfilters_i
)
torch
.
mm
(
inp_buffer
[:
nhot
].
T
,
out_buffer
[:
nhot
],
out
=
dfilters_i
)
# NK @ KC
torch
.
mm
(
out_buffer
[:
nhot
],
filters_KC
,
out
=
inp_buffer
[:
nhot
])
SpconvOps
.
scatter_add_cpu
(
din_tv
,
inp_buffer_tv
,
inp_indices
)
...
...
@@ -1295,8 +1331,12 @@ def implicit_gemm(features: torch.Tensor,
alloc
=
TorchAllocator
(
features
.
device
)
features_tv
=
torch_tensor_to_tv
(
features
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_mask_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
,
tv
.
uint32
)
for
t
in
pair_mask_fwd_splits
]
mask_argsort_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_fwd_splits
]
pair_mask_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
,
tv
.
uint32
)
for
t
in
pair_mask_fwd_splits
]
mask_argsort_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_fwd_splits
]
filters_tv
=
torch_tensor_to_tv
(
filters
)
mask
=
np
.
concatenate
(
masks
)
...
...
@@ -1307,9 +1347,11 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
fp32_accum
=
False
mask_width
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
mask_width
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
if
is_train
:
...
...
@@ -1543,12 +1585,19 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_bwd_tv
=
torch_tensor_to_tv
(
pair_bwd
)
pair_mask_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
pair_mask_fwd_splits
]
pair_mask_bwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
pair_mask_bwd_splits
]
mask_argsort_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_fwd_splits
]
mask_argsort_bwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_bwd_splits
]
pair_mask_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
pair_mask_fwd_splits
]
pair_mask_bwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
pair_mask_bwd_splits
]
mask_argsort_fwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_fwd_splits
]
mask_argsort_bwd_splits_tv
=
[
torch_tensor_to_tv
(
t
)
for
t
in
mask_argsort_bwd_splits
]
filters_tv
=
torch_tensor_to_tv
(
filters
)
out_bp_tv
=
torch_tensor_to_tv
(
out_bp
)
...
...
@@ -1564,10 +1613,12 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
fp32_accum
=
False
ConvGemmOps
.
implicit_gemm_backward
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_output_fwd_tv
,
mask_tv
,
mask_width
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
ConvGemmOps
.
implicit_gemm_backward
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_output_fwd_tv
,
mask_tv
,
mask_width
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
dfilters
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
return
din
,
dfilters
...
...
@@ -1849,3 +1900,65 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
indice_pairs_tv
,
stream
)
return
din
def
indice_avgpool_implicit_gemm
(
features
:
torch
.
Tensor
,
indice_pairs
:
torch
.
Tensor
,
num_activate_out
,
calc_count
:
bool
):
# torch.cuda.synchronize()
# t = time.time()
stream
=
get_current_stream
()
# CONV.stream_synchronize(stream)
# t = time.time()
if
not
features
.
is_contiguous
():
features
=
features
.
contiguous
()
out_channel
=
features
.
shape
[
-
1
]
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
assert
features
.
is_cuda
stream
=
get_current_stream
()
out_features_tv
=
torch_tensor_to_tv
(
out_features
)
features_tv
=
torch_tensor_to_tv
(
features
)
indice_pairs_tv
=
torch_tensor_to_tv
(
indice_pairs
)
count_out
=
torch
.
Tensor
()
count_out_tv
=
tv
.
Tensor
()
if
calc_count
:
count_out
=
torch
.
zeros
((
num_activate_out
,),
dtype
=
torch
.
int32
,
device
=
features
.
device
)
count_out_tv
=
torch_tensor_to_tv
(
count_out
)
SpconvOps
.
avgpool_implicit_gemm_forward
(
out_features_tv
,
features_tv
,
indice_pairs_tv
,
count_out_tv
,
stream
)
# CONV.stream_synchronize(stream)
# print("M", time.time() - t)
return
out_features
,
count_out
def
indice_avgpool_implicit_gemm_backward
(
out_bp
,
indice_pairs
,
count_out
):
# torch.cuda.synchronize()
# t = time.time()
out_channel
=
out_bp
.
shape
[
-
1
]
din
=
torch
.
zeros
((
indice_pairs
.
shape
[
1
],
out_bp
.
shape
[
1
]),
dtype
=
out_bp
.
dtype
,
device
=
out_bp
.
device
)
assert
out_bp
.
is_cuda
if
not
out_bp
.
is_contiguous
():
out_bp
=
out_bp
.
contiguous
()
stream
=
get_current_stream
()
count_out_tv
=
torch_tensor_to_tv
(
count_out
)
out_bp_tv
=
torch_tensor_to_tv
(
out_bp
)
din_tv
=
torch_tensor_to_tv
(
din
)
indice_pairs_tv
=
torch_tensor_to_tv
(
indice_pairs
)
SpconvOps
.
avgpool_implicit_gemm_backward
(
out_bp_tv
,
din_tv
,
indice_pairs_tv
,
count_out_tv
,
stream
)
return
din
def
maximum_value_int_
(
ten
:
torch
.
Tensor
,
value
:
int
):
stream
=
0
if
not
CPU_ONLY_BUILD
:
stream
=
get_current_stream
()
else
:
assert
not
ten
.
is_cuda
SpconvOps
.
maximum_value_int
(
torch_tensor_to_tv
(
ten
),
value
,
stream
)
spconv/pytorch/pool.py
View file @
21bb00ae
...
...
@@ -30,6 +30,7 @@ from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData, expand_nd
from
spconv.pytorch.modules
import
SparseModule
from
spconv.cppconstants
import
CPU_ONLY_BUILD
from
spconv.utils
import
nullcontext
from
.conv
import
_MAX_NUM_VOXELS_DURING_TRAINING
class
SparseMaxPool
(
SparseModule
):
...
...
@@ -42,6 +43,7 @@ class SparseMaxPool(SparseModule):
indice_key
:
Optional
[
str
]
=
None
,
subm
:
bool
=
False
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool
,
self
).
__init__
(
name
=
name
)
self
.
ndim
=
ndim
...
...
@@ -52,6 +54,12 @@ class SparseMaxPool(SparseModule):
self
.
stride
=
expand_nd
(
ndim
,
stride
)
self
.
padding
=
expand_nd
(
ndim
,
padding
)
self
.
subm
=
subm
if
record_voxel_count
and
not
self
.
subm
:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self
.
register_buffer
(
_MAX_NUM_VOXELS_DURING_TRAINING
,
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
self
.
record_voxel_count
=
record_voxel_count
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
indice_key
=
indice_key
kv
=
int
(
np
.
prod
(
kernel_size
))
...
...
@@ -220,6 +228,136 @@ class SparseMaxPool(SparseModule):
features
.
shape
[
0
])
out_tensor
.
benchmark_record
[
self
.
name
][
"num_out_points"
].
append
(
out_features
.
shape
[
0
])
if
not
self
.
subm
and
self
.
record_voxel_count
:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
ops
.
maximum_value_int_
(
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
),
outids
.
shape
[
0
])
out_tensor
=
out_tensor
.
replace_feature
(
out_features
)
out_tensor
.
indices
=
outids
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
spatial_shape
=
out_spatial_shape
return
out_tensor
class
SparseAvgPool
(
SparseModule
):
def
__init__
(
self
,
ndim
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Optional
[
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
indice_key
:
Optional
[
str
]
=
None
,
subm
:
bool
=
False
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseAvgPool
,
self
).
__init__
(
name
=
name
)
self
.
ndim
=
ndim
self
.
kernel_size
=
expand_nd
(
ndim
,
kernel_size
)
if
stride
is
None
:
self
.
stride
=
self
.
kernel_size
.
copy
()
else
:
self
.
stride
=
expand_nd
(
ndim
,
stride
)
self
.
padding
=
expand_nd
(
ndim
,
padding
)
self
.
subm
=
subm
if
record_voxel_count
and
not
self
.
subm
:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self
.
register_buffer
(
_MAX_NUM_VOXELS_DURING_TRAINING
,
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
self
.
record_voxel_count
=
record_voxel_count
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
indice_key
=
indice_key
kv
=
int
(
np
.
prod
(
kernel_size
))
assert
kv
<=
32
,
"avg pool only support implicit-gemm style indice gen with kv <= 32 limit"
self
.
algo
=
ConvAlgo
.
MaskImplicitGemm
def
extra_repr
(
self
):
s
=
(
'kernel_size={kernel_size}'
', stride={stride}'
)
if
self
.
padding
!=
(
0
,
)
*
len
(
self
.
padding
):
s
+=
', padding={padding}'
if
self
.
dilation
!=
(
1
,
)
*
len
(
self
.
dilation
):
s
+=
', dilation={dilation}'
if
self
.
algo
is
not
None
:
s
+=
f
', algo=
{
self
.
algo
}
'
return
s
.
format
(
**
self
.
__dict__
)
def
forward
(
self
,
input
):
assert
isinstance
(
input
,
spconv
.
SparseConvTensor
)
features
=
input
.
features
device
=
features
.
device
indices
=
input
.
indices
spatial_shape
=
input
.
spatial_shape
batch_size
=
input
.
batch_size
if
not
self
.
subm
:
out_spatial_shape
=
ops
.
get_conv_output_size
(
spatial_shape
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
)
else
:
out_spatial_shape
=
spatial_shape
out_tensor
=
input
.
shadow_copy
()
out_padding
=
[
0
]
*
self
.
ndim
indice_dict
=
input
.
indice_dict
.
copy
()
profile_ctx
=
nullcontext
()
if
input
.
_timer
is
not
None
and
self
.
_sparse_unique_name
:
profile_ctx
=
input
.
_timer
.
namespace
(
self
.
_sparse_unique_name
)
with
profile_ctx
:
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
res
=
ops
.
get_indice_pairs_implicit_gemm
(
indices
,
batch_size
,
spatial_shape
,
self
.
algo
,
ksize
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
out_padding
=
out_padding
,
subm
=
self
.
subm
,
is_train
=
(
not
self
.
subm
)
or
self
.
training
,
alloc
=
input
.
thrust_allocator
,
timer
=
input
.
_timer
)
outids
=
res
[
0
]
num_inds_per_loc
=
res
[
1
]
pair_fwd
=
res
[
2
]
pair_bwd
=
res
[
3
]
pair_mask_fwd_splits
=
res
[
4
]
pair_mask_bwd_splits
=
res
[
5
]
mask_argsort_fwd_splits
=
res
[
6
]
mask_argsort_bwd_splits
=
res
[
7
]
masks
=
res
[
8
]
if
self
.
indice_key
is
not
None
:
indice_data
=
ImplicitGemmIndiceData
(
outids
,
indices
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
=
pair_mask_fwd_splits
,
pair_mask_bwd_splits
=
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
=
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
=
mask_argsort_bwd_splits
,
masks
=
masks
,
is_subm
=
self
.
subm
,
spatial_shape
=
spatial_shape
,
out_spatial_shape
=
out_spatial_shape
,
algo
=
self
.
algo
,
ksize
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
)
msg
=
f
"your indice key
{
self
.
indice_key
}
already exists in this sparse tensor."
assert
self
.
indice_key
not
in
indice_dict
,
msg
indice_dict
[
self
.
indice_key
]
=
indice_data
out_features
=
Fsp
.
indice_avgpool_implicit_gemm
(
features
,
pair_fwd
,
pair_bwd
,
outids
.
shape
[
0
],
self
.
training
)
if
not
self
.
subm
and
self
.
record_voxel_count
:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
ops
.
maximum_value_int_
(
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
),
outids
.
shape
[
0
])
out_tensor
=
out_tensor
.
replace_feature
(
out_features
)
out_tensor
.
indices
=
outids
out_tensor
.
indice_dict
=
indice_dict
...
...
@@ -235,14 +373,17 @@ class SparseMaxPool1d(SparseMaxPool):
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool1d
,
self
).
__init__
(
1
,
super
(
SparseMaxPool1d
,
self
).
__init__
(
1
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
...
...
@@ -254,14 +395,17 @@ class SparseMaxPool2d(SparseMaxPool):
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool2d
,
self
).
__init__
(
2
,
super
(
SparseMaxPool2d
,
self
).
__init__
(
2
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
...
...
@@ -273,14 +417,17 @@ class SparseMaxPool3d(SparseMaxPool):
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool3d
,
self
).
__init__
(
3
,
super
(
SparseMaxPool3d
,
self
).
__init__
(
3
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
...
...
@@ -292,12 +439,87 @@ class SparseMaxPool4d(SparseMaxPool):
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool4d
,
self
).
__init__
(
4
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
class
SparseAvgPool1d
(
SparseAvgPool
):
"""avg pool that use real point count instead of kernel size.
"""
def
__init__
(
self
,
kernel_size
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseAvgPool1d
,
self
).
__init__
(
1
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
class
SparseAvgPool2d
(
SparseAvgPool
):
"""avg pool that use real point count instead of kernel size.
"""
def
__init__
(
self
,
kernel_size
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseAvgPool2d
,
self
).
__init__
(
2
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
class
SparseAvgPool3d
(
SparseAvgPool
):
"""avg pool that use real point count instead of kernel size.
"""
def
__init__
(
self
,
kernel_size
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
indice_key
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
record_voxel_count
:
bool
=
False
,
name
=
None
):
super
(
SparseMaxPool4d
,
self
).
__init__
(
4
,
super
(
SparseAvgPool3d
,
self
).
__init__
(
3
,
kernel_size
,
stride
,
padding
,
dilation
,
indice_key
=
indice_key
,
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
test/CMakeLists.txt
deleted
100644 → 0
View file @
899008fa
set
(
CATCH_HEADER
${
PROJECT_SOURCE_DIR
}
/third_party/catch2
)
add_library
(
catch_main OBJECT src/catch_main.cpp
)
# target_compile_features(catch_main PUBLIC cxx_std_2a)
set_property
(
TARGET catch_main PROPERTY CXX_STANDARD 14
)
target_include_directories
(
catch_main PRIVATE
${
CATCH_HEADER
}
)
file
(
GLOB files
"src/test_*.cpp"
)
foreach
(
file
${
files
}
)
get_filename_component
(
file_basename
${
file
}
NAME_WE
)
string
(
REGEX REPLACE
"test_([^$]+)"
"test-
\\
1"
testcase
${
file_basename
}
)
add_executable
(
${
testcase
}
${
file
}
$<TARGET_OBJECTS:catch_main>
)
set_property
(
TARGET
${
testcase
}
PROPERTY CXX_STANDARD 14
)
# set_target_properties(${testcase} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
# set_property(TARGET ${testcase} PROPERTY CUDA_STANDARD 14)
target_compile_definitions
(
${
testcase
}
PRIVATE
CATCH_CONFIG_FAST_COMPILE
)
target_include_directories
(
${
testcase
}
PRIVATE
${
CATCH_HEADER
}
${
ALL_INCLUDE
}
)
target_link_libraries
(
${
testcase
}
${
ALL_LIBS
}
pybind11::embed -Wl,--no-as-needed spconv
)
add_test
(
NAME
"
${
testcase
}
"
COMMAND
${
testcase
}
WORKING_DIRECTORY
${
CMAKE_SOURCE_DIR
}
)
endforeach
()
\ No newline at end of file
test/benchmark.py
View file @
21bb00ae
...
...
@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
,
record_voxel_count
=
True
),
spconv
.
SubMConv3d
(
64
,
96
,
3
,
...
...
@@ -332,7 +332,7 @@ def main():
voxels_th
=
torch
.
from_numpy
(
voxels
).
to
(
device
).
to
(
dtype
)
coors_th
=
torch
.
from_numpy
(
coors
).
to
(
device
).
int
()
voxels_th
.
requires_grad
=
True
algo
=
spconv
.
ConvAlgo
.
MaskImplicitGemm
algo
=
spconv
.
ConvAlgo
.
Native
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
...
...
@@ -385,21 +385,25 @@ def main():
torch
.
cuda
.
synchronize
()
# sort_bench()
times
.
append
(
time
.
time
()
-
t
)
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
# breakpoint()
print
(
"spconv time"
,
np
.
mean
(
times
[
10
:]))
times
=
[]
for
i
in
range
(
10
):
out
=
net
(
voxels_th
,
coors_th
,
1
)
print
(
"------------"
)
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
out
.
features
.
backward
(
dout_t
)
torch
.
cuda
.
synchronize
()
times
.
append
(
time
.
time
()
-
t
)
# # print((net.grid == -1).float().sum(), net.grid.numel())
# # print("spconv time", time.time() - t)
print
(
"spconv bw time"
,
np
.
mean
(
times
[
5
:]))
#
times = []
#
for i in range(10):
#
out = net(voxels_th, coors_th, 1)
#
print("------------")
#
torch.cuda.synchronize()
#
t = time.time()
#
out.features.backward(dout_t)
#
torch.cuda.synchronize()
#
times.append(time.time() - t)
# #
#
print((net.grid == -1).float().sum(), net.grid.numel())
# #
#
print("spconv time", time.time() - t)
#
print("spconv bw time", np.mean(times[5:]))
if
__name__
==
"__main__"
:
...
...
test/test_conv.py
View file @
21bb00ae
...
...
@@ -248,7 +248,7 @@ def test_spconv3d():
ConvAlgo
.
Native
,
ConvAlgo
.
MaskImplicitGemm
,
ConvAlgo
.
MaskSplitImplicitGemm
]
algos
=
[
ConvAlgo
.
Native
]
algos
=
[
ConvAlgo
.
Native
,
ConvAlgo
.
MaskImplicitGemm
,
ConvAlgo
.
MaskSplitImplicitGemm
]
for
dev
,
shape
,
bs
,
IC
,
OC
,
k
,
s
,
p
,
d
,
al
in
params_grid
(
devices
,
shapes
,
batchsizes
,
in_channels
,
out_channels
,
ksizes
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment