Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
aa26c99e
Commit
aa26c99e
authored
Dec 29, 2022
by
yan.yan
Browse files
working on quantization
parent
ee8c9465
Changes
29
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
745 additions
and
465 deletions
+745
-465
example/fuse_bn_act.py
example/fuse_bn_act.py
+11
-1
example/libspconv/main.cu
example/libspconv/main.cu
+2
-2
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+235
-0
example/mnist/mnist_sparse.py
example/mnist/mnist_sparse.py
+0
-0
pyproject.toml
pyproject.toml
+1
-1
spconv/algo.py
spconv/algo.py
+53
-21
spconv/algocore.py
spconv/algocore.py
+7
-2
spconv/build.py
spconv/build.py
+2
-2
spconv/core.py
spconv/core.py
+60
-43
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+8
-51
spconv/core_cc/csrc/sparse/convops/convops.pyi
spconv/core_cc/csrc/sparse/convops/convops.pyi
+11
-6
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+6
-4
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+44
-108
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+86
-28
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+12
-11
spconv/csrc/sparse/pointops.py
spconv/csrc/sparse/pointops.py
+3
-3
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+128
-33
spconv/pytorch/core.py
spconv/pytorch/core.py
+3
-3
spconv/pytorch/functional.py
spconv/pytorch/functional.py
+2
-6
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+71
-140
No files found.
example/fuse_bn_act.py
View file @
aa26c99e
...
@@ -207,7 +207,7 @@ class Net(nn.Module):
...
@@ -207,7 +207,7 @@ class Net(nn.Module):
pool_algo
=
algo
pool_algo
=
algo
# pool_algo = ConvAlgo.Native
# pool_algo = ConvAlgo.Native
self
.
net
=
spconv
.
SparseSequential
(
self
.
net
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
3
,
64
,
3
,
bias
=
False
,
indice_key
=
"c0"
,
spconv
.
SubMConv3d
(
16
,
64
,
3
,
bias
=
False
,
indice_key
=
"c0"
,
algo
=
algo
),
algo
=
algo
),
nn
.
BatchNorm1d
(
64
),
nn
.
BatchNorm1d
(
64
),
nn
.
ReLU
(),
nn
.
ReLU
(),
...
@@ -373,6 +373,11 @@ class Net(nn.Module):
...
@@ -373,6 +373,11 @@ class Net(nn.Module):
x
=
spconv
.
SparseConvTensor
(
features
,
coors
,
self
.
shape
,
batch_size
,
voxel_num
=
vx_num
)
x
=
spconv
.
SparseConvTensor
(
features
,
coors
,
self
.
shape
,
batch_size
,
voxel_num
=
vx_num
)
return
self
.
net
(
x
)
return
self
.
net
(
x
)
def
_set_enable_int8_test_inplace
(
simple_module
:
torch
.
fx
.
GraphModule
,
enable
:
bool
):
for
m
in
simple_module
.
modules
():
if
isinstance
(
m
,
SparseConvolution
):
if
m
.
in_channels
%
32
==
0
and
m
.
out_channels
%
32
==
0
:
m
.
enable_int8_test_mode
=
enable
class
MyTracer
(
torch
.
fx
.
Tracer
):
class
MyTracer
(
torch
.
fx
.
Tracer
):
...
@@ -387,6 +392,7 @@ def main():
...
@@ -387,6 +392,7 @@ def main():
torch
.
backends
.
cudnn
.
allow_tf32
=
False
torch
.
backends
.
cudnn
.
allow_tf32
=
False
with
open
(
Path
(
__file__
).
parent
.
parent
/
"test"
/
"data"
/
"test_spconv.pkl"
,
"rb"
)
as
f
:
with
open
(
Path
(
__file__
).
parent
.
parent
/
"test"
/
"data"
/
"test_spconv.pkl"
,
"rb"
)
as
f
:
(
voxels
,
coors
,
spatial_shape
)
=
pickle
.
load
(
f
)
(
voxels
,
coors
,
spatial_shape
)
=
pickle
.
load
(
f
)
voxels
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
voxels
.
shape
[
0
],
16
]).
astype
(
np
.
float32
)
np
.
random
.
seed
(
50051
)
np
.
random
.
seed
(
50051
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
device_cpu
=
torch
.
device
(
"cpu:0"
)
device_cpu
=
torch
.
device
(
"cpu:0"
)
...
@@ -408,6 +414,10 @@ def main():
...
@@ -408,6 +414,10 @@ def main():
out_fused
=
net_fused
(
voxels_th_cuda
,
coors_th_cuda
,
1
)
out_fused
=
net_fused
(
voxels_th_cuda
,
coors_th_cuda
,
1
)
res
=
Fsp
.
sparse_add_hash_based
(
out_ref
,
out_fused
.
minus
())
res
=
Fsp
.
sparse_add_hash_based
(
out_ref
,
out_fused
.
minus
())
print
(
torch
.
linalg
.
norm
(
res
.
features
))
print
(
torch
.
linalg
.
norm
(
res
.
features
))
_set_enable_int8_test_inplace
(
net_fused
,
True
)
qvoxels_cuda
=
voxels_th_cuda
.
to
(
torch
.
int8
)
out_int8
=
net_fused
(
qvoxels_cuda
,
coors_th_cuda
,
1
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
example/libspconv/main.cu
View file @
aa26c99e
...
@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
...
@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
{
SPCONV_ALLOC_OUT_FEATURES
,
out_features
}};
{
SPCONV_ALLOC_OUT_FEATURES
,
out_features
}};
StaticAllocator
alloc2
(
tensor_dict
);
StaticAllocator
alloc2
(
tensor_dict
);
ConvTunerSimple
tuner
(
ConvMain
::
get_all_conv_algo_desp
());
ConvTunerSimple
tuner
(
ConvMain
::
get_all_conv_algo_desp
());
auto
conv_r
e
s
=
ConvGemmOps
::
implicit_gemm
(
auto
conv_r
un_statu
s
=
ConvGemmOps
::
implicit_gemm
(
alloc2
,
tuner
,
input_features_real
,
weights
,
pair_fwd_real
,
alloc2
,
tuner
,
input_features_real
,
weights
,
pair_fwd_real
,
pair_mask_splits
,
mask_argsort_splits
,
num_act_out_real
,
pair_mask_splits
,
mask_argsort_splits
,
num_act_out_real
,
mask_tensor
,
arch
,
false
,
is_subm
,
mask_tensor
,
arch
,
false
,
is_subm
,
...
@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
...
@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
1.0
/*bias alpha, only used for leaky relu*/
,
1.0
/*bias alpha, only used for leaky relu*/
,
0.0
/*unused for now*/
,
tv
::
gemm
::
Activation
::
kReLU
);
0.0
/*unused for now*/
,
tv
::
gemm
::
Activation
::
kReLU
);
tv
::
ssprint
(
"selected conv algo"
,
tv
::
ssprint
(
"selected conv algo"
,
std
::
get
<
1
>
(
conv_r
e
s
).
algo_desp
.
__repr__
());
std
::
get
<
1
>
(
conv_r
un_statu
s
).
algo_desp
.
__repr__
());
// FINISH!!!
// FINISH!!!
}
}
// calc maximum number of output points.
// calc maximum number of output points.
...
...
example/mnist/mnist_qat.py
0 → 100644
View file @
aa26c99e
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
argparse
import
torch
import
spconv.pytorch
as
spconv
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
from
torchvision
import
datasets
,
transforms
from
torch.optim.lr_scheduler
import
StepLR
import
contextlib
import
torch.cuda.amp
@
contextlib
.
contextmanager
def
identity_ctx
():
yield
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
nn
.
BatchNorm1d
(
1
),
spconv
.
SubMConv2d
(
1
,
32
,
3
,
1
),
nn
.
ReLU
(),
spconv
.
SubMConv2d
(
32
,
64
,
3
,
1
),
nn
.
ReLU
(),
spconv
.
SparseConv2d
(
64
,
64
,
2
,
2
),
spconv
.
ToDense
(),
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
x
.
reshape
(
-
1
,
28
,
28
,
1
))
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
dropout1
(
x
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
model
.
train
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
amp_ctx
=
contextlib
.
nullcontext
()
if
args
.
fp16
:
amp_ctx
=
torch
.
cuda
.
amp
.
autocast
()
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
with
amp_ctx
:
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
scale
=
1.0
if
args
.
fp16
:
assert
loss
.
dtype
is
torch
.
float32
scaler
.
scale
(
loss
).
backward
()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
# scaler.unscale_(optim)
# Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
# You may use the same value for max_norm here as you would without gradient scaling.
# torch.nn.utils.clip_grad_norm_(models[0].net.parameters(), max_norm=0.1)
scaler
.
step
(
optimizer
)
# Updates the scale for next iteration.
scaler
.
update
()
scale
=
scaler
.
get_scale
()
else
:
loss
.
backward
()
optimizer
.
step
()
if
batch_idx
%
args
.
log_interval
==
0
:
print
(
'Train Epoch: {} [{}/{} ({:.0f}%)]
\t
Loss: {:.6f}'
.
format
(
epoch
,
batch_idx
*
len
(
data
),
len
(
train_loader
.
dataset
),
100.
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
def
test
(
args
,
model
,
device
,
test_loader
):
model
.
eval
()
test_loss
=
0
correct
=
0
amp_ctx
=
contextlib
.
nullcontext
()
if
args
.
fp16
:
amp_ctx
=
torch
.
cuda
.
amp
.
autocast
()
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
with
amp_ctx
:
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
# get the index of the max log-probability
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'
\n
Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
\n
'
.
format
(
test_loss
,
correct
,
len
(
test_loader
.
dataset
),
100.
*
correct
/
len
(
test_loader
.
dataset
)))
def
main
():
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
64
,
metavar
=
'N'
,
help
=
'input batch size for training (default: 64)'
)
parser
.
add_argument
(
'--test-batch-size'
,
type
=
int
,
default
=
1000
,
metavar
=
'N'
,
help
=
'input batch size for testing (default: 1000)'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
14
,
metavar
=
'N'
,
help
=
'number of epochs to train (default: 14)'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
1.0
,
metavar
=
'LR'
,
help
=
'learning rate (default: 1.0)'
)
parser
.
add_argument
(
'--gamma'
,
type
=
float
,
default
=
0.7
,
metavar
=
'M'
,
help
=
'Learning rate step gamma (default: 0.7)'
)
parser
.
add_argument
(
'--no-cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1
,
metavar
=
'S'
,
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--log-interval'
,
type
=
int
,
default
=
10
,
metavar
=
'N'
,
help
=
'how many batches to wait before logging training status'
)
parser
.
add_argument
(
'--save-model'
,
action
=
'store_true'
,
default
=
False
,
help
=
'For Saving the current Model'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
default
=
False
,
help
=
'For mixed precision training'
)
args
=
parser
.
parse_args
()
use_cuda
=
not
args
.
no_cuda
and
torch
.
cuda
.
is_available
()
torch
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'../data'
,
train
=
True
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
**
kwargs
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'../data'
,
train
=
False
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size
=
args
.
test_batch_size
,
shuffle
=
True
,
**
kwargs
)
model
=
Net
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
test
(
args
,
model
,
device
,
test_loader
)
scheduler
.
step
()
if
args
.
save_model
:
torch
.
save
(
model
.
state_dict
(),
"mnist_cnn.pt"
)
if
__name__
==
'__main__'
:
main
()
example/mnist_sparse.py
→
example/mnist
/mnist
_sparse.py
View file @
aa26c99e
File moved
pyproject.toml
View file @
aa26c99e
[build-system]
[build-system]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.7"
]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.7"
]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu1
18
-0.3.
4
-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu1
20
-0.3.
7
-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend
=
"setuptools.build_meta"
build-backend
=
"setuptools.build_meta"
spconv/algo.py
View file @
aa26c99e
...
@@ -616,6 +616,7 @@ class SimpleConv:
...
@@ -616,6 +616,7 @@ class SimpleConv:
algocore
.
get_conv_algo_desp_from_param
(
p
)
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
for
p
in
ALL_IMPGEMM_PARAMS
]
]
self
.
all_desps
=
all_desps
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
...
@@ -648,13 +649,13 @@ class SimpleConv:
...
@@ -648,13 +649,13 @@ class SimpleConv:
tile_ms_list
,
tile_ns_list
,
tile_ks_list
,
tile_shape_to_algos
)
tile_ms_list
,
tile_ns_list
,
tile_ks_list
,
tile_shape_to_algos
)
self
.
kc_forward_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
self
.
kc_forward_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
int
,
bool
],
BestConvAlgoByProfile
]
=
{}
# for forward
BestConvAlgoByProfile
]
=
{}
# for forward
self
.
kc_dgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
self
.
kc_dgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
,
bool
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
}
# for backward weight
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
,
bool
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
...
@@ -679,11 +680,12 @@ class SimpleConv:
...
@@ -679,11 +680,12 @@ class SimpleConv:
op_type
:
ConvOpType
,
op_type
:
ConvOpType
,
mask_width
:
int
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
use_tf32
:
bool
=
True
):
use_tf32
:
bool
=
True
,
bias
:
tv
.
Tensor
=
tv
.
Tensor
(),
scale
:
tv
.
Tensor
=
tv
.
Tensor
()):
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
#
and out.dtype == tv.float16
use_f32_as_accum
=
False
use_f32_as_accum
=
False
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
# for 3d conv, if reduce axis is too large, may cause nan during
# for 3d conv, if reduce axis is too large, may cause nan during
...
@@ -703,6 +705,10 @@ class SimpleConv:
...
@@ -703,6 +705,10 @@ class SimpleConv:
layout_w
.
interleave
,
layout_o
.
interleave
,
inp
.
dtype
,
layout_w
.
interleave
,
layout_o
.
interleave
,
inp
.
dtype
,
weight
.
dtype
,
out
.
dtype
,
op_type
.
value
)
weight
.
dtype
,
out
.
dtype
,
op_type
.
value
)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
# for d in self.all_desps:
# print(d)
# print(len(desps))
# breakpoint()
if
desps
is
None
or
len
(
desps
)
==
0
:
if
desps
is
None
or
len
(
desps
)
==
0
:
return
finally_algos
return
finally_algos
for
desp
in
desps
:
for
desp
in
desps
:
...
@@ -726,11 +732,21 @@ class SimpleConv:
...
@@ -726,11 +732,21 @@ class SimpleConv:
ldw
=
weight
.
dim
(
-
1
)
ldw
=
weight
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
mask_width_valid
=
True
mask_width_valid
=
True
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
assert
mask_width
>
0
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
require_dynamic_mask
=
kv
>
32
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
not
bias
.
empty
()
and
not
scale
.
empty
():
# int8 inference, bias/scale dtype must equal to compute dtype in gemm
assert
bias
.
dtype
==
scale
.
dtype
if
desp
.
dcomp
!=
bias
.
dtype
:
continue
if
not
desp
.
is_int8_inference
:
continue
else
:
if
desp
.
is_int8_inference
:
continue
if
desp
.
is_nvrtc
:
if
desp
.
is_nvrtc
:
if
not
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
if
not
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
continue
continue
...
@@ -747,6 +763,12 @@ class SimpleConv:
...
@@ -747,6 +763,12 @@ class SimpleConv:
continue
continue
if
SPCONV_DEBUG_NVRTC_KERNELS
:
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
desp
.
is_nvrtc
=
True
if
require_dynamic_mask
:
if
not
desp
.
dynamic_mask
:
continue
else
:
if
desp
.
dynamic_mask
:
continue
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
return
finally_algos
return
finally_algos
...
@@ -758,11 +780,12 @@ class SimpleConv:
...
@@ -758,11 +780,12 @@ class SimpleConv:
k
:
int
,
k
:
int
,
c
:
int
,
c
:
int
,
arch
:
Tuple
[
int
,
int
],
arch
:
Tuple
[
int
,
int
],
mask_width
:
int
=
-
1
):
mask_width
:
int
=
-
1
,
need_dynamic_mask
:
bool
=
False
):
if
not
op_type
==
ConvOpType
.
kBackwardWeight
:
if
not
op_type
==
ConvOpType
.
kBackwardWeight
:
# fwd and dgrad don't need
# fwd and dgrad don't need
mask_width
=
-
1
mask_width
=
-
1
key
=
(
i_dtype
,
w_dtype
,
o_dtype
,
k
,
c
,
arch
[
0
],
arch
[
1
],
mask_width
)
key
=
(
i_dtype
,
w_dtype
,
o_dtype
,
k
,
c
,
arch
[
0
],
arch
[
1
],
mask_width
,
need_dynamic_mask
)
if
op_type
==
ConvOpType
.
kForward
:
if
op_type
==
ConvOpType
.
kForward
:
return
self
.
kc_forward_cache
.
get
(
key
,
None
)
return
self
.
kc_forward_cache
.
get
(
key
,
None
)
elif
op_type
==
ConvOpType
.
kBackwardInput
:
elif
op_type
==
ConvOpType
.
kBackwardInput
:
...
@@ -795,8 +818,9 @@ class SimpleConv:
...
@@ -795,8 +818,9 @@ class SimpleConv:
cudadevrt
=
str
(
cudadevrt_p
)
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
verbose
=
True
,
custom_names
=
custom_names
)
custom_names
=
custom_names
,
verbose_path
=
"/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8"
)
mod
.
load
()
mod
.
load
()
return
mod
,
kernel
return
mod
,
kernel
...
@@ -824,7 +848,6 @@ class SimpleConv:
...
@@ -824,7 +848,6 @@ class SimpleConv:
mask_argsort
:
tv
.
Tensor
,
mask_argsort
:
tv
.
Tensor
,
indices
:
tv
.
Tensor
,
indices
:
tv
.
Tensor
,
reverse_mask
:
bool
,
reverse_mask
:
bool
,
mask_int_count
:
int
=
1
,
mask_filter
:
int
=
0xffffffff
,
mask_filter
:
int
=
0xffffffff
,
mask_width
:
int
=
-
1
,
mask_width
:
int
=
-
1
,
mask_output
:
tv
.
Tensor
=
tv
.
Tensor
(),
mask_output
:
tv
.
Tensor
=
tv
.
Tensor
(),
...
@@ -832,17 +855,20 @@ class SimpleConv:
...
@@ -832,17 +855,20 @@ class SimpleConv:
beta
:
float
=
0.0
,
beta
:
float
=
0.0
,
stream
:
int
=
0
,
stream
:
int
=
0
,
fp32_accum
:
Optional
[
bool
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
use_tf32
:
bool
=
True
):
use_tf32
:
bool
=
True
,
bias
:
tv
.
Tensor
=
tv
.
Tensor
(),
scale
:
tv
.
Tensor
=
tv
.
Tensor
()):
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
layout_o
,
arch
,
op_type
,
mask_width
,
layout_o
,
arch
,
op_type
,
mask_width
,
fp32_accum
,
use_tf32
)
fp32_accum
,
use_tf32
,
bias
,
scale
)
inp
=
inp
.
clone
()
inp
=
inp
.
clone
()
weight
=
weight
.
clone
()
weight
=
weight
.
clone
()
output
=
output
.
clone
()
output
=
output
.
clone
()
print
(
len
(
avail
),
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
bias
.
dtype
,
scale
.
dtype
,
bias
.
empty
(),
scale
.
empty
())
channel_k
=
output
.
dim
(
1
)
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
need_dynamic_mask
=
weight
.
dim
(
1
)
>
32
times
:
List
[
float
]
=
[]
times
:
List
[
float
]
=
[]
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
group_by_algo
=
{}
group_by_algo
=
{}
...
@@ -865,8 +891,9 @@ class SimpleConv:
...
@@ -865,8 +891,9 @@ class SimpleConv:
params
.
indices
=
indices
params
.
indices
=
indices
params
.
mask
=
mask
params
.
mask
=
mask
params
.
mask_output
=
mask_output
params
.
mask_output
=
mask_output
params
.
mask_int_count
=
mask_int_count
if
desp
.
is_int8_inference
:
params
.
bias
=
bias
params
.
scale
=
scale
# if op_type == ConvOpType.kBackwardWeight:
# if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty()
# assert not mask_output.empty()
if
op_type
==
ConvOpType
.
kBackwardInput
:
if
op_type
==
ConvOpType
.
kBackwardInput
:
...
@@ -909,7 +936,7 @@ class SimpleConv:
...
@@ -909,7 +936,7 @@ class SimpleConv:
# fwd and dgrad don't need
# fwd and dgrad don't need
mask_width
=
-
1
mask_width
=
-
1
key
=
(
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
channel_k
,
channel_c
,
key
=
(
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
channel_k
,
channel_c
,
arch
[
0
],
arch
[
1
],
mask_width
)
arch
[
0
],
arch
[
1
],
mask_width
,
need_dynamic_mask
)
with
self
.
lock
:
with
self
.
lock
:
if
op_type
==
ConvOpType
.
kForward
:
if
op_type
==
ConvOpType
.
kForward
:
self
.
kc_forward_cache
[
key
]
=
res
self
.
kc_forward_cache
[
key
]
=
res
...
@@ -945,7 +972,9 @@ class SimpleConv:
...
@@ -945,7 +972,9 @@ class SimpleConv:
act_alpha
:
float
=
0.0
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
mask_int_count
:
Union
[
int
,
None
]
=
None
):
scale
:
Optional
[
tv
.
Tensor
]
=
None
,
output_add
:
Optional
[
tv
.
Tensor
]
=
None
):
channel_k
=
output
.
dim
(
1
)
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
# GemmMainUnitTest.stream_synchronize(stream)
# GemmMainUnitTest.stream_synchronize(stream)
...
@@ -986,9 +1015,12 @@ class SimpleConv:
...
@@ -986,9 +1015,12 @@ class SimpleConv:
params
.
mask_filter
=
mask_filter
params
.
mask_filter
=
mask_filter
params
.
mask_output
=
mask_output
params
.
mask_output
=
mask_output
params
.
reverse_mask
=
reverse_mask
params
.
reverse_mask
=
reverse_mask
params
.
mask_int_count
=
mask_int_count
if
bias
is
not
None
:
if
bias
is
not
None
:
params
.
bias
=
bias
params
.
bias
=
bias
if
output_add
is
not
None
and
algo_desp
.
is_int8_inference
:
params
.
output_add
=
output_add
if
scale
is
not
None
and
algo_desp
.
is_int8_inference
:
params
.
scale
=
scale
if
timer
.
enable
:
if
timer
.
enable
:
assert
timer
.
_timer
is
not
None
assert
timer
.
_timer
is
not
None
params
.
timer
=
timer
.
_timer
params
.
timer
=
timer
.
_timer
...
...
spconv/algocore.py
View file @
aa26c99e
...
@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
...
@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
def
_assign_gemm_desp_props
(
desp
:
Union
[
ConvAlgoDesp
,
GemmAlgoDesp
],
def
_assign_gemm_desp_props
(
desp
:
Union
[
ConvAlgoDesp
,
GemmAlgoDesp
],
p
:
Union
[
GemmAlgoParams
,
ConvAlgoParams
]):
p
:
Union
[
GemmAlgoParams
,
ConvAlgoParams
]):
desp
.
dtype_a
=
p
.
dtype_a
.
tv_dtype
desp
.
dtype_a
=
p
.
dtype_a
.
tv_dtype
desp
.
dtype_b
=
p
.
dtype_
a
.
tv_dtype
desp
.
dtype_b
=
p
.
dtype_
b
.
tv_dtype
desp
.
dtype_c
=
p
.
dtype_
a
.
tv_dtype
desp
.
dtype_c
=
p
.
dtype_
c
.
tv_dtype
desp
.
dacc
=
p
.
dtype_acc
.
tv_dtype
desp
.
dacc
=
p
.
dtype_acc
.
tv_dtype
desp
.
dcomp
=
p
.
dtype_comp
.
tv_dtype
desp
.
dcomp
=
p
.
dtype_comp
.
tv_dtype
desp
.
trans_a
=
p
.
trans_a
desp
.
trans_a
=
p
.
trans_a
...
@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
...
@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp
.
element_per_access_a
=
ker
.
input_spec
.
input_iter_a
.
element_per_acc
desp
.
element_per_access_a
=
ker
.
input_spec
.
input_iter_a
.
element_per_acc
desp
.
element_per_access_b
=
ker
.
input_spec
.
input_iter_b
.
element_per_acc
desp
.
element_per_access_b
=
ker
.
input_spec
.
input_iter_b
.
element_per_acc
desp
.
element_per_access_c
=
ker
.
output_spec
.
out_iter
.
element_per_acc
desp
.
element_per_access_c
=
ker
.
output_spec
.
out_iter
.
element_per_acc
desp
.
is_int8_inference
=
ker
.
int8_inference
desp
.
dynamic_mask
=
ker
.
dynamic_mask
desp
.
min_arch
=
ker
.
min_arch
()
desp
.
min_arch
=
ker
.
min_arch
()
return
desp
return
desp
...
@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
...
@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
desp
.
interleave_o
)
desp
.
interleave_o
)
p
.
mask_sparse
=
desp
.
mask_sparse
p
.
mask_sparse
=
desp
.
mask_sparse
p
.
increment_k_first
=
desp
.
increment_k_first
p
.
increment_k_first
=
desp
.
increment_k_first
p
.
int8_inference
=
desp
.
is_int8_inference
p
.
dynamic_mask
=
desp
.
dynamic_mask
return
p
return
p
spconv/build.py
View file @
aa26c99e
...
@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
...
@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from
spconv.csrc.sparse.inference
import
InferenceOps
from
spconv.csrc.sparse.inference
import
InferenceOps
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_AMPERE_PARAMS
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_AMPERE_PARAMS
all_shuffle
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_shuffle
))
#
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu
=
GemmMainUnitTest
(
all_shuffle
)
cu
=
GemmMainUnitTest
(
all_shuffle
)
cu
.
namespace
=
"cumm.gemm.main"
cu
.
namespace
=
"cumm.gemm.main"
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
#
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
.
namespace
=
"cumm.conv.main"
convcu
.
namespace
=
"cumm.conv.main"
gemmtuner
=
GemmTunerSimple
(
cu
)
gemmtuner
=
GemmTunerSimple
(
cu
)
...
...
spconv/core.py
View file @
aa26c99e
...
@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
],
[
2
,
3
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
]
]
IMPLGEMM_TURING_PARAMS
=
[
IMPLGEMM_TURING_PARAMS
=
[
...
@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
2
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
NHWC
,
...
@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
16
,
16
),
(
16
,
16
,
16
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
16
,
16
),
(
16
,
16
,
16
),
...
...
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
aa26c99e
...
@@ -144,7 +144,7 @@ class SpconvOps:
...
@@ -144,7 +144,7 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
...
@@ -167,11 +167,10 @@ class SpconvOps:
...
@@ -167,11 +167,10 @@ class SpconvOps:
dilation:
dilation:
transposed:
transposed:
stream_int:
stream_int:
mask_int_count:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
...
@@ -194,11 +193,10 @@ class SpconvOps:
...
@@ -194,11 +193,10 @@ class SpconvOps:
dilation:
dilation:
transposed:
transposed:
stream_int:
stream_int:
mask_int_count:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
...
@@ -214,7 +212,6 @@ class SpconvOps:
...
@@ -214,7 +212,6 @@ class SpconvOps:
indice_pair_mask:
indice_pair_mask:
backward:
backward:
stream_int:
stream_int:
mask_int_count:
"""
"""
...
...
@staticmethod
@staticmethod
...
@@ -383,65 +380,25 @@ class SpconvOps:
...
@@ -383,65 +380,25 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_allocator
_mask32
(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0
, mask_count: int = 1
) -> Tensor:
"""
"""
Args:
Args:
data:
data:
alloc_func:
alloc_func:
indices:
indices:
stream:
stream:
mask_count:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_allocator_
mask32_
v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0
, mask_count: int = 1
) -> Tensor:
"""
"""
Args:
Args:
data:
data:
allocator:
allocator:
indices:
indices:
stream:
stream:
"""
mask_count:
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
"""
...
...
@staticmethod
@staticmethod
...
@@ -598,7 +555,7 @@ class SpconvOps:
...
@@ -598,7 +555,7 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor,
int,
int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
"""
Args:
Args:
allocator:
allocator:
...
...
spconv/core_cc/csrc/sparse/convops/convops.pyi
View file @
aa26c99e
...
@@ -20,7 +20,7 @@ class ConvTunerSimple:
...
@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch:
arch:
"""
"""
...
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]:
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True
, bias: Tensor = Tensor(), scale: Tensor = Tensor()
) -> List[ConvAlgoDesp]:
"""
"""
Args:
Args:
inp:
inp:
...
@@ -38,6 +38,8 @@ class ConvTunerSimple:
...
@@ -38,6 +38,8 @@ class ConvTunerSimple:
auto_fp32_accum:
auto_fp32_accum:
fp32_accum:
fp32_accum:
use_tf32:
use_tf32:
bias:
scale:
"""
"""
...
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...
@@ -48,7 +50,7 @@ class ConvTunerSimple:
...
@@ -48,7 +50,7 @@ class ConvTunerSimple:
stream_int:
stream_int:
"""
"""
...
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0,
mask_int_count: int = 1,
auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True
, bias: Tensor = Tensor(), scale: Tensor = Tensor()
) -> Tuple[ConvTuneResult, float]:
"""
"""
Args:
Args:
op_type:
op_type:
...
@@ -72,14 +74,15 @@ class ConvTunerSimple:
...
@@ -72,14 +74,15 @@ class ConvTunerSimple:
alpha:
alpha:
beta:
beta:
stream_int:
stream_int:
mask_int_count:
auto_fp32_accum:
auto_fp32_accum:
fp32_accum:
fp32_accum:
num_run:
num_run:
use_tf32:
use_tf32:
bias:
scale:
"""
"""
...
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1
, need_dynamic_mask: bool = False
) -> Tuple[Any, bool]:
"""
"""
Args:
Args:
op_type:
op_type:
...
@@ -90,9 +93,10 @@ class ConvTunerSimple:
...
@@ -90,9 +93,10 @@ class ConvTunerSimple:
c:
c:
arch:
arch:
mask_width:
mask_width:
need_dynamic_mask:
"""
"""
...
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0,
mask_int_count: int = 1,
workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_
, scale: Tensor = Tensor(), output_add: Tensor = Tensor()
) -> None:
"""
"""
Args:
Args:
profile_res:
profile_res:
...
@@ -110,7 +114,6 @@ class ConvTunerSimple:
...
@@ -110,7 +114,6 @@ class ConvTunerSimple:
alpha:
alpha:
beta:
beta:
stream_int:
stream_int:
mask_int_count:
workspace:
workspace:
verbose:
verbose:
timer:
timer:
...
@@ -119,6 +122,8 @@ class ConvTunerSimple:
...
@@ -119,6 +122,8 @@ class ConvTunerSimple:
act_alpha:
act_alpha:
act_beta:
act_beta:
act_type:
act_type:
scale:
output_add:
"""
"""
...
...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
aa26c99e
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor,
mask_int_count: int,
arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True
, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0
) -> Tuple[int, Any]:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -75,7 +75,6 @@ class ConvGemmOps:
...
@@ -75,7 +75,6 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
mask_argsort_fwd_splits:
num_activate_out:
num_activate_out:
masks:
masks:
mask_int_count:
arch:
arch:
is_train:
is_train:
is_subm:
is_subm:
...
@@ -88,10 +87,14 @@ class ConvGemmOps:
...
@@ -88,10 +87,14 @@ class ConvGemmOps:
act_beta:
act_beta:
act_type:
act_type:
use_tf32:
use_tf32:
output_scale:
scale:
output_add:
output_add_scale:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor,
mask_int_count: int,
arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -107,7 +110,6 @@ class ConvGemmOps:
...
@@ -107,7 +110,6 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_argsort_bwd_splits:
mask_output_fwd:
mask_output_fwd:
masks:
masks:
mask_int_count:
arch:
arch:
mask_width:
mask_width:
is_subm:
is_subm:
...
...
spconv/csrc/sparse/all.py
View file @
aa26c99e
...
@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
...
@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
from
spconv.constants
import
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
,
AllocKeys
from
spconv.constants
import
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
,
AllocKeys
import
re
import
re
import
os
import
os
from
cumm.gemm.codeops
import
dispatch
class
CustomThrustLib
(
pccm
.
Class
):
class
CustomThrustLib
(
pccm
.
Class
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
...
@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...
@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
...
@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int
, mask_int_count
);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
}}
"""
)
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
...
@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...
@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
...
@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int
, mask_int_count
);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
}}
"""
)
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
...
@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()"
)
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"backward"
,
"bool"
,
"false"
)
code
.
arg
(
"backward"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int = 0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int = 0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...
@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
...
@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc,
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward,
ksize_, dilation_, indice_pair_mask, backward,
stream_int
, mask_int_count
);
stream_int);
}}
}}
"""
)
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
...
@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
return
code
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
,
int_count
:
int
=
1
):
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
return
code
.
make_invalid
()
...
@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
...
@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
"tv::Tensor()"
,
"tv::Tensor()"
,
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
code_after_include
=
f
"""
code
.
arg
(
"mask_count"
,
"int"
,
"1"
,
pyanno
=
"int"
)
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
if
not
use_allocator
:
if
not
use_allocator
:
...
@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
...
@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
if (indices.empty()){{
indices = tv::empty({{data.dim(0)
/
{
int_count
}
}}, tv::int32, 0);
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer();
// auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
"""
)
using T_ = TV_DECLTYPE(I);
# nested tv::dispatch may cause compiler bug in msvc.
using T =
{
"T_"
if
int_count
==
1
else
f
"thrust::tuple<
{
', '
.
join
([
'T_'
]
*
int_count
)
}
>
"
}
;
for
dtype
in
dispatch
(
code
,
[
dtypes
.
int32
,
dtypes
.
int64
,
dtypes
.
uint32
,
dtypes
.
uint64
],
"data.dtype()"
):
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
code
.
raw
(
f
"""
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
using T_ =
{
dtype
}
;
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
tv::dispatch_int<1, 2, 3, 4>(mask_count, [&](auto IV){{
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
constexpr int I = TV_DECLTYPE(IV)::value;
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0) /
{
int_count
}
, ptr_k);
// we can't use thrust::tuple in mp_repeat_c directly because
}});
// thrust tuple actually has fixed size template arguments.
using T = tv::mp_rename<tv::mp_repeat_c<tv::mp_list<T_>, I>, thrust::tuple>;
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}});
"""
)
code
.
raw
(
f
"""
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
return indices;
"""
)
"""
)
...
@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
...
@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask32
(
self
):
def
sort_1d_by_key_allocator
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
False
)
return
self
.
sort_1d_by_key_allocator_template
(
False
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask32_v2
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask128
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
False
,
4
)
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask128_v2
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
True
,
4
)
def
sort_1d_by_key_allocator_mask_auto_template
(
self
,
use_allocator
:
bool
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"data"
,
"tv::Tensor"
)
if
not
use_allocator
:
code
.
arg
(
"alloc_param"
,
"std::function<std::uintptr_t(std::size_t)>"
)
else
:
code
.
arg
(
"alloc_param"
,
"ThrustAllocator&"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
,
"tv::Tensor()"
,
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32
{
"_v2"
if
use_allocator
else
""
}
(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128
{
"_v2"
if
use_allocator
else
""
}
(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
sort_1d_by_key_allocator_mask_auto
(
self
):
return
self
.
sort_1d_by_key_allocator_mask_auto_template
(
False
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
sort_1d_by_key_allocator_mask_auto_v2
(
self
):
return
self
.
sort_1d_by_key_allocator_mask_auto_template
(
True
)
@
_STATIC_FUNCTION
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_v2
(
self
):
def
sort_1d_by_key_allocator_v2
(
self
):
# for cpp only
return
self
.
sort_1d_by_key_allocator_template
(
True
)
return
self
.
sort_1d_by_key_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
...
@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
...
@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int hash_size = 2 * num_act_out_bound;
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
hash_size =
tv::align_up(
int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory)
, 2)
;
}}
}}
size_t res = 0;
size_t res = 0;
if (subm){{
if (subm){{
...
@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
...
@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound;
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
hash_size =
tv::align_up(
int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory)
, 2)
;
}}
}}
if (use_int64_hash_k){{
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
...
@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
...
@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int mask_int_count =
(kv + 31) /
32;
int mask_int_count =
tv::div_up(kv,
32
)
;
if (mask_int_count > 1 && mask_int_count < 4)
//
if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4;
//
mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
//
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape;
std::vector<int> out_shape;
if (!subm){{
if (!subm){{
...
@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
...
@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
);
pair_mask = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
);
}}else{{
}}else{{
pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_in
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_in
,
mask_int_count}}, tv::uint32, 0, stream_int);
}}
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int
, mask_int_count
);
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}}
}}
"""
)
"""
)
with
code
.
else_
():
with
code
.
else_
():
...
@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
...
@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_out
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_out
,
mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor();
pair_mask_bwd = tv::Tensor();
if (is_train){{
if (is_train){{
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
{{mask_split_count, indices.dim(0)
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, indices.dim(0)
,
mask_int_count}}, tv::uint32, 0, stream_int);
}}
}}
}}
}}
if (!direct_table){{
if (!direct_table){{
...
@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
...
@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int
, mask_int_count
);
transposed, stream_int);
}}else{{
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int
, mask_int_count
);
transposed, stream_int);
}}
}}
}}
}}
"""
)
"""
)
...
@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
...
@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
}}
}}
}}else{{
}}else{{
if (!is_train){{
if (!is_train){{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{
}}else{{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_bwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
mask_argsort_bwd[0], stream_int, mask_int_count);
}}
}}
}}
}}
}}
}}
"""
)
"""
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
return std::make_tuple(mask_tensor, num_act_out
, mask_int_count
);
return std::make_tuple(mask_tensor, num_act_out);
"""
)
"""
)
return
code
.
ret
(
"std::tuple<tv::Tensor,
int,
int>"
)
return
code
.
ret
(
"std::tuple<tv::Tensor, int>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
@
pccm
.
static_function
...
...
spconv/csrc/sparse/convops.py
View file @
aa26c99e
...
@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
"int, int, int, int, int>"
))
"int, int, int, int, int>"
))
self
.
add_typedef
(
self
.
add_typedef
(
"algo_cache_key_t"
,
"std::tuple<int, int, int, int, "
"algo_cache_key_t"
,
"std::tuple<int, int, int, int, "
"int, int, int, int>"
)
"int, int, int, int
, bool
>"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::ConvAlgoDesp>"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::ConvAlgoDesp>"
)
self
.
add_member
(
self
.
add_member
(
...
@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"auto_fp32_accum"
,
"bool"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
)
code
.
arg
(
"fp32_accum"
,
"bool"
)
code
.
arg
(
"fp32_accum"
,
"bool"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"bias"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
...
@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
}}
bool require_dynamic_mask = kv > 32;
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!bias.empty() && !scale.empty()){{
TV_ASSERT_RT_ERR(bias.dtype() == scale.dtype(), "bias/scale dtype must equal to compute dtype in gemm");
if (desp.dcomp != bias.dtype()){{
continue;
}}
if (!desp.is_int8_inference){{
continue;
}}
}}else{{
if (desp.is_int8_inference){{
continue;
}}
}}
auto desp2 = desp;
auto desp2 = desp;
if (desp.is_nvrtc){{
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
...
@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
}}
}}
}}
}}
}}
if (require_dynamic_mask){{
if (!desp.dynamic_mask){{
continue;
}}
}}else{{
if (desp.dynamic_mask){{
continue;
}}
}}
finally_algos.push_back(desp2);
finally_algos.push_back(desp2);
}}
}}
}}
}}
...
@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
code
.
arg
(
"num_run"
,
"int"
,
"5"
)
code
.
arg
(
"num_run"
,
"int"
,
"5"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"bias"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
...
@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum, use_tf32);
auto_fp32_accum, fp32_accum, use_tf32,
bias, scale);
inp = inp.clone();
inp = inp.clone();
weight = weight.clone();
weight = weight.clone();
bool need_dynamic_mask = weight.dim(1) > 32;
output = output.clone();
output = output.clone();
int channel_k = output.dim(1);
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
int channel_c = inp.dim(1);
weight = weight.view(channel_k, -1, channel_c);
std::vector<ConvTuneResult> all_profile_res;
std::vector<ConvTuneResult> all_profile_res;
std::unordered_set<int> splitk_tests;
std::unordered_set<int> splitk_tests;
...
@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices;
params.indices = indices;
params.mask = mask;
params.mask = mask;
params.mask_output = mask_output;
params.mask_output = mask_output;
params.mask_int_count = mask_int_count;
if (desp.is_int8_inference){{
params.bias = bias;
params.scale = scale;
}}
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
...
@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
}}
algo_cache_key_t key;
algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width);
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width
, need_dynamic_mask
);
{{
{{
std::lock_guard<std::mutex> guard(mutex_);
std::lock_guard<std::mutex> guard(mutex_);
...
@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"k, c"
,
"int"
)
code
.
arg
(
"k, c"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
code
.
arg
(
"need_dynamic_mask"
,
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, bool>"
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, bool>"
)
...
@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
}}
algo_cache_key_t key;
algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width);
std::get<0>(arch), std::get<1>(arch), mask_width
, need_dynamic_mask
);
ConvTuneResult res;
ConvTuneResult res;
bool exists = false;
bool exists = false;
{{
{{
...
@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
arg
(
"workspace"
,
"tv::Tensor"
,
"tv::Tensor()"
,
code
.
arg
(
"workspace"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"verbose"
,
f
"bool"
,
"false"
)
code
.
arg
(
"verbose"
,
f
"bool"
,
"false"
)
...
@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"act_alpha"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_alpha"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
return
code
...
@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.output = output;
params.output = output;
params.verbose = verbose;
params.verbose = verbose;
params.bias = bias;
params.bias = bias;
params.scale = scale;
params.split_k_slices = split_k_slices;
params.split_k_slices = split_k_slices;
params.alpha = alpha;
params.alpha = alpha;
...
@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.act_alpha = act_alpha;
params.act_alpha = act_alpha;
params.act_beta = act_beta;
params.act_beta = act_beta;
params.act_type = act_type;
params.act_type = act_type;
if (!output_add.empty() && desp.is_int8_inference){{
params.output_add = output_add;
}}
params.stream = stream_int;
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.mask_argsort = mask_argsort;
params.indices = indices;
params.indices = indices;
...
@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width;
params.mask_width = mask_width;
params.mask_output = mask_output;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{
if (timer.enable()){{
params.timer = timer;
params.timer = timer;
...
@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>"
)
"std::vector<tv::Tensor>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"mask_int_count"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
...
@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"output_scale"
,
"float"
,
"1.0"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add_scale"
,
"float"
,
"1.0"
)
code
.
arg
(
"output_dtype"
,
"int"
,
"-1"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
...
@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
int num_split = pair_mask_fwd_splits.size();
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
int mask_int_count = tv::div_up(kv, 32);
tv::Tensor out_features;
tv::Tensor out_features;
if (output_dtype < 0){{
output_dtype = int(features.dtype());
}}
if (is_subm){{
if (is_subm){{
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}},
features.
dtype
(
), features.device(), stream_int);
{{num_activate_out, out_channel}},
tv::DType(output_
dtype), features.device(), stream_int);
}}else{{
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}},
features.
dtype
(
), features.device(), stream_int);
{{num_activate_out, out_channel}},
tv::DType(output_
dtype), features.device(), stream_int);
}}
}}
// auto start_ev = tv::CUDAEvent();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// start_ev.record(stream_int);
...
@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
tv::Tensor(), // mask_output
1.0, 0.0,
1.0, 0.0,
stream_int,
stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum,
auto_fp32_accum,
fp32_accum,
fp32_accum,
5, // num_run
5, // num_run
use_tf32);
use_tf32,
bias,
scale);
tune_res = std::get<0>(tune_res_time);
tune_res = std::get<0>(tune_res_time);
}}
}}
float alpha = 1.0;
if (tune_res.algo_desp.is_int8_inference){{
alpha = output_scale;
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
if (is_train){{
mask_output_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskOutputFwd
)
}
,
mask_output_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskOutputFwd
)
}
,
{{num_split, tv::div_up(num_activate_out, mask_width)
*
mask_int_count}},
{{num_split, tv::div_up(num_activate_out, mask_width)
,
mask_int_count}},
tv::uint32, features.device(), stream_int);
tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
...
@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
for (int j = 0; j < num_split; ++j){{
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
float beta = j == 0 ? 0 : 1;
if (!bias.empty()){{
if (!bias.empty() && !tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = 1;
beta = 1;
}}
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
}}
if (j > 0){{
if (j > 0){{
bias = tv::Tensor();
bias = tv::Tensor();
}}
}}
...
@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // reverse_mask
false, // reverse_mask
mask_ptr[j],
mask_ptr[j],
-1, // mask_width
-1, // mask_width
1.0
, beta,
alpha
, beta,
stream_int,
stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace
tv::Tensor(), // workspace
false, // verbose
false, // verbose
timer,
timer,
...
@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
bias,
bias,
act_alpha,
act_alpha,
act_beta,
act_beta,
act_type);
act_type,
scale,
output_add);
}}
}}
// auto end_ev = tv::CUDAEvent();
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
// end_ev.record(stream_int);
...
@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"mask_int_count"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"mask_width"
,
"int"
)
...
@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
tv::Tensor(), // mask_output
1.0, 0.0,
1.0, 0.0,
stream_int,
stream_int,
mask_int_count,
auto_fp32_accum,
auto_fp32_accum,
fp32_accum,
fp32_accum,
5, // num_run
5, // num_run
...
@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
tv::Tensor(), // mask_output
1.0, 0.0,
1.0, 0.0,
stream_int,
stream_int,
mask_int_count,
auto_fp32_accum,
auto_fp32_accum,
fp32_accum,
fp32_accum,
5, // num_run
5, // num_run
...
@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width
-1, // mask_width
1.0, beta,
1.0, beta,
stream_int,
stream_int,
mask_int_count,
tv::Tensor(), // workspace
tv::Tensor(), // workspace
false, // verbose
false, // verbose
timer);
timer);
...
@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width,
mask_width,
1.0, 0.0,
1.0, 0.0,
stream_int,
stream_int,
mask_int_count,
workspace, // workspace
workspace, // workspace
false, // verbose
false, // verbose
timer);
timer);
...
...
spconv/csrc/sparse/indices.py
View file @
aa26c99e
...
@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2));
// uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset);
loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS;
int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
...
@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f
"tv::array<int,
{
self
.
ndim
}
>"
)
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>();
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
// indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, num_act_out]
// indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context();
auto ctx = tv::Context();
...
@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()"
)
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"is_train"
,
"bool"
,
"true"
)
code
.
arg
(
"is_train"
,
"bool"
,
"true"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int num_act_in_real = indices.dim(0);
int num_act_in_real = indices.dim(0);
...
@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
padding[i] = (ksize[i] / 2) * dilation[i];
padding[i] = (ksize[i] / 2) * dilation[i];
}}
}}
int kv = ksize.op<tv::arrayops::prod>();
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
...
@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (!indice_pair_mask.empty()){{
if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() ==
2
, "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() ==
3
, "error");
// indice_pair_mask: [mask_split_count, num_act_in]
// indice_pair_mask: [mask_split_count, num_act_in
, num_mask_per_point
]
if (indice_pair_mask.dim(0) == 2){{
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
...
@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv, is_train);
indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{
}}else{{
// indice_pair_mask: [1, num_act_in]
// indice_pair_mask: [1, num_act_in
, num_mask_per_point
]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1)
if (mask_int_count == 1){{
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else
}}
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
else{{
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
}}
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t,
{
loc_type
}
>, loc_iter, hash,
launcher_num_act_in(calc_subm_conv_indices_mask<table_t,
{
loc_type
}
>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
...
...
spconv/csrc/sparse/pointops.py
View file @
aa26c99e
...
@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
...
@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(),
count.data_ptr<int>(),
layout, voxels.dim(0));
layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<
{
self
.
dtype
}
>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<
{
self
.
dtype
}
>(),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv,
voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0));
grid_size_tv, grid_stride_tv, points.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
auto voxel_launcher = tv::cuda::Launch(count_val, custream);
auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{
if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<
{
self
.
dtype
}
>(),
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<
{
self
.
dtype
}
>(),
...
...
spconv/pytorch/conv.py
View file @
aa26c99e
...
@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
...
@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
from
torch.nn.init
import
calculate_gain
from
torch.nn.init
import
calculate_gain
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
from
torch.nn
import
functional
as
F
FILTER_HWIO
=
False
FILTER_HWIO
=
False
_MAX_NUM_VOXELS_DURING_TRAINING
=
"max_num_voxels_during_training"
_MAX_NUM_VOXELS_DURING_TRAINING
=
"max_num_voxels_during_training"
def
_apply_act
(
x
:
torch
.
Tensor
,
act_type
:
tv
.
gemm
.
Activation
,
act_alpha
:
float
,
act_beta
:
float
):
if
act_type
==
tv
.
gemm
.
Activation
.
None_
:
return
x
elif
act_type
==
tv
.
gemm
.
Activation
.
ReLU
:
return
F
.
relu
(
x
)
elif
act_type
==
tv
.
gemm
.
Activation
.
Sigmoid
:
return
F
.
sigmoid
(
x
)
elif
act_type
==
tv
.
gemm
.
Activation
.
LeakyReLU
:
return
F
.
leaky_relu
(
x
,
act_alpha
)
else
:
raise
NotImplementedError
class
SparseConvolution
(
SparseModule
):
class
SparseConvolution
(
SparseModule
):
__constants__
=
[
__constants__
=
[
...
@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
...
@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
self
.
record_voxel_count
=
record_voxel_count
self
.
record_voxel_count
=
record_voxel_count
if
algo
is
None
:
if
algo
is
None
:
if
kv
<=
32
and
not
CPU_ONLY_BUILD
:
if
kv
<=
128
and
not
CPU_ONLY_BUILD
:
if
kv
<
8
:
if
kv
<
8
:
algo
=
ConvAlgo
.
MaskImplicitGemm
algo
=
ConvAlgo
.
MaskImplicitGemm
else
:
else
:
...
@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
...
@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
self
.
act_type
=
act_type
self
.
act_type
=
act_type
self
.
act_alpha
=
act_alpha
self
.
act_alpha
=
act_alpha
self
.
act_beta
=
act_beta
self
.
act_beta
=
act_beta
self
.
enable_int8_test_mode
:
bool
=
False
self
.
_int8_weight
=
torch
.
Tensor
()
# calculated by max(abs(weight)) for each channel
self
.
_int8_weight_scale
=
torch
.
Tensor
()
# calculated by scale self.bias with _int8_input_scale
self
.
_int8_bias
=
torch
.
Tensor
()
# int8 inference must set _int8_input_scale
self
.
_int8_input_scale
:
Optional
[
float
]
=
None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self
.
_int8_output_scale
:
Optional
[
float
]
=
None
if
self
.
conv1x1
:
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
self
.
reset_parameters
()
self
.
reset_parameters
()
...
@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
...
@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
None
return
None
def
set_int8_test
(
self
,
enable
:
bool
,
input_scale
:
float
,
output_scale
:
Optional
[
float
]
=
None
,
weight_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
self
.
_int8_input_scale
=
input_scale
self
.
_int8_output_scale
=
output_scale
if
weight_scale
is
not
None
:
self
.
_int8_weight_scale
=
weight_scale
self
.
enable_int8_test_mode
=
enable
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
error_msgs
):
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
_MAX_NUM_VOXELS_DURING_TRAINING
not
in
state_dict
:
name
=
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
state_dict
[
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
]
=
torch
.
zeros
(
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
name
not
in
state_dict
:
state_dict
[
name
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
)
1
,
dtype
=
torch
.
int32
)
if
not
SAVED_WEIGHT_LAYOUT
:
if
not
SAVED_WEIGHT_LAYOUT
:
return
return
...
@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
...
@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
def
is_inverseable
(
self
):
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
):
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
input
,
self
.
weight
,
self
.
bias
,
add_input
)
def
_conv_forward
(
self
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
assert
isinstance
(
input
,
SparseConvTensor
)
assert
isinstance
(
input
,
SparseConvTensor
)
assert
input
.
features
.
shape
[
assert
input
.
features
.
shape
[
1
]
==
self
.
in_channels
,
"channel size mismatch"
1
]
==
self
.
in_channels
,
"channel size mismatch"
...
@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
...
@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
indices
=
input
.
indices
indices
=
input
.
indices
spatial_shape
=
input
.
spatial_shape
spatial_shape
=
input
.
spatial_shape
batch_size
=
input
.
batch_size
batch_size
=
input
.
batch_size
bias_for_training
=
self
.
bias
if
self
.
training
else
None
bias_for_training
=
bias
if
self
.
training
else
None
bias_for_infer
=
self
.
bias
if
not
self
.
training
else
None
bias_for_infer
=
bias
if
not
self
.
training
else
None
output_scale
=
None
output_add_scale
=
1.0
if
self
.
enable_int8_test_mode
:
assert
not
self
.
training
,
"must in eval mode"
assert
self
.
algo
==
ConvAlgo
.
MaskImplicitGemm
,
"int8 inference only support MaskImplicitGemm"
assert
bias_for_infer
is
not
None
,
"conv-bn-relu must be fused"
assert
self
.
_int8_input_scale
is
not
None
if
features
.
dtype
!=
torch
.
int8
:
# quantize
features
=
torch
.
clamp
(
torch
.
round
(
features
/
self
.
_int8_input_scale
),
-
128
,
127
).
to
(
torch
.
int8
)
output_scale
=
self
.
_int8_output_scale
int8_out_scale
=
output_scale
if
int8_out_scale
is
None
:
int8_out_scale
=
1
if
add_input
is
not
None
:
assert
add_input
.
int8_scale
is
not
None
,
"only support int8 add"
output_add_scale
=
add_input
.
int8_scale
if
self
.
_int8_weight
.
numel
()
==
0
:
with
torch
.
no_grad
():
assert
ALL_WEIGHT_IS_KRSC
weight_scales
=
torch
.
abs
(
weight
).
view
(
self
.
out_channels
,
-
1
).
max
(
1
)[
0
]
num_1s
=
[
1
]
*
(
self
.
ndim
+
1
)
self
.
_int8_weight
=
(
weight
/
weight_scales
.
view
(
self
.
out_channels
,
*
num_1s
)
*
127
).
to
(
torch
.
int8
)
if
self
.
_int8_weight_scale
.
numel
()
==
0
:
self
.
_int8_weight_scale
=
int8_out_scale
/
(
self
.
_int8_input_scale
*
weight_scales
)
self
.
_int8_bias
=
bias_for_infer
*
int8_out_scale
if
self
.
training
:
if
self
.
training
:
msg
=
"act don't support backward, only used in inference"
msg
=
"act don't support backward, only used in inference"
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
...
@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
...
@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
"out_channels"
:
self
.
out_channels
,
"out_channels"
:
self
.
out_channels
,
}
}
}
}
if
self
.
conv1x1
:
if
self
.
conv1x1
and
not
self
.
enable_int8_test_mode
:
# in int8 test mode, we don't implement conv1x1 via mm.
if
FILTER_HWIO
:
if
FILTER_HWIO
:
features
=
torch
.
mm
(
features
=
torch
.
mm
(
input
.
features
,
input
.
features
,
self
.
weight
.
view
(
self
.
out_channels
,
self
.
in_channels
).
T
)
weight
.
view
(
self
.
out_channels
,
self
.
in_channels
).
T
)
else
:
else
:
features
=
torch
.
mm
(
features
=
torch
.
mm
(
input
.
features
,
input
.
features
,
self
.
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
if
self
.
bias
is
not
None
:
if
bias
is
not
None
:
features
+=
self
.
bias
features
+=
bias
out_tensor
=
out_tensor
.
replace_feature
(
features
)
out_tensor
=
out_tensor
.
replace_feature
(
features
)
# padding may change spatial shape of conv 1x1.
# padding may change spatial shape of conv 1x1.
out_tensor
.
spatial_shape
=
out_spatial_shape
out_tensor
.
spatial_shape
=
out_spatial_shape
...
@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
...
@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
if
self
.
subm
:
if
self
.
subm
:
out_features
=
Fsp
.
indice_subm_conv
(
out_features
=
Fsp
.
indice_subm_conv
(
features
,
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pairs_calc
,
indice_pair_num
,
indice_pair_num
,
outids
.
shape
[
0
],
outids
.
shape
[
0
],
...
@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
...
@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
if
self
.
inverse
:
if
self
.
inverse
:
out_features
=
Fsp
.
indice_inverse_conv
(
out_features
=
Fsp
.
indice_inverse_conv
(
features
,
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pairs_calc
,
indice_pair_num
,
indice_pair_num
,
outids
.
shape
[
0
],
outids
.
shape
[
0
],
...
@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
...
@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
else
:
else
:
out_features
=
Fsp
.
indice_conv
(
out_features
=
Fsp
.
indice_conv
(
features
,
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pairs_calc
,
indice_pair_num
,
indice_pair_num
,
outids
.
shape
[
0
],
outids
.
shape
[
0
],
...
@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
...
@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits
=
datas
.
mask_argsort_fwd_splits
mask_argsort_fwd_splits
=
datas
.
mask_argsort_fwd_splits
mask_argsort_bwd_splits
=
datas
.
mask_argsort_bwd_splits
mask_argsort_bwd_splits
=
datas
.
mask_argsort_bwd_splits
masks
=
datas
.
masks
masks
=
datas
.
masks
mask_int_count
=
datas
.
mask_int_count
assert
self
.
subm
,
"only support reuse subm indices"
assert
self
.
subm
,
"only support reuse subm indices"
self
.
_check_subm_reuse_valid
(
input
,
spatial_shape
,
self
.
_check_subm_reuse_valid
(
input
,
spatial_shape
,
datas
)
datas
)
else
:
else
:
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
# we need to gen bwd indices for regular conv
# we need to gen bwd indices for regular conv
# because it may be inversed.
# because it may be inversed.
...
@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
...
@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
print
(
msg
,
file
=
sys
.
stderr
)
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
(
indices
)
spconv_save_debug_data
(
indices
)
raise
e
raise
e
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
out_tensor
.
benchmark_record
[
self
.
name
][
"indice_gen_time"
].
append
(
interval
)
outids
=
res
[
0
]
outids
=
res
[
0
]
num_inds_per_loc
=
res
[
1
]
num_inds_per_loc
=
res
[
1
]
pair_fwd
=
res
[
2
]
pair_fwd
=
res
[
2
]
...
@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
...
@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits
=
res
[
6
]
mask_argsort_fwd_splits
=
res
[
6
]
mask_argsort_bwd_splits
=
res
[
7
]
mask_argsort_bwd_splits
=
res
[
7
]
masks
=
res
[
8
]
masks
=
res
[
8
]
mask_int_count
=
res
[
9
]
if
self
.
indice_key
is
not
None
:
if
self
.
indice_key
is
not
None
:
indice_data
=
ImplicitGemmIndiceData
(
indice_data
=
ImplicitGemmIndiceData
(
outids
,
outids
,
...
@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
...
@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
ksize
=
self
.
kernel_size
,
ksize
=
self
.
kernel_size
,
stride
=
self
.
stride
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
dilation
=
self
.
dilation
)
mask_int_count
=
mask_int_count
)
msg
=
f
"your indice key
{
self
.
indice_key
}
already exists in this sparse tensor."
msg
=
f
"your indice key
{
self
.
indice_key
}
already exists in this sparse tensor."
assert
self
.
indice_key
not
in
indice_dict
,
msg
assert
self
.
indice_key
not
in
indice_dict
,
msg
indice_dict
[
self
.
indice_key
]
=
indice_data
indice_dict
[
self
.
indice_key
]
=
indice_data
...
@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
...
@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
t
=
time
.
time
()
num_activate_out
=
outids
.
shape
[
0
]
num_activate_out
=
outids
.
shape
[
0
]
out_features
=
Fsp
.
implicit_gemm
(
weight_cur
=
weight
features
,
self
.
weight
,
pair_fwd
,
pair_bwd
,
bias_cur
=
bias_for_infer
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
if
self
.
enable_int8_test_mode
:
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
assert
features
.
dtype
==
torch
.
int8
,
"in int8 test mode, feature must be int8"
num_activate_out
,
masks
,
mask_int_count
,
self
.
training
,
self
.
subm
,
weight_cur
=
self
.
_int8_weight
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
=
self
.
_int8_bias
bias_for_infer
,
if
self
.
training
:
self
.
act_alpha
,
out_features
=
Fsp
.
implicit_gemm
(
self
.
act_beta
,
features
,
weight_cur
,
pair_fwd
,
pair_bwd
,
self
.
act_type
)
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
else
:
output_dtype
=
None
if
self
.
_int8_output_scale
is
None
:
output_dtype
=
weight
.
dtype
out_features
,
_
,
_
=
ops
.
implicit_gemm
(
features
,
weight_cur
,
pair_fwd
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
,
# TODO do we really need output scale to scale bias in kernel?
1.0
,
# output_scale
self
.
_int8_weight_scale
,
# scale
output_add
=
add_input
.
features
if
add_input
is
not
None
else
None
,
output_add_scale
=
output_add_scale
,
output_dtype
=
output_dtype
)
if
bias_for_training
is
not
None
:
if
bias_for_training
is
not
None
:
out_features
+=
bias_for_training
out_features
+=
bias_for_training
if
input
.
benchmark
:
if
input
.
benchmark
:
...
@@ -581,9 +675,10 @@ class SparseConvolution(SparseModule):
...
@@ -581,9 +675,10 @@ class SparseConvolution(SparseModule):
out_tensor
.
indices
=
outids
out_tensor
.
indices
=
outids
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
spatial_shape
=
out_spatial_shape
out_tensor
.
spatial_shape
=
out_spatial_shape
# print(outids.shape, spatial_shape, self.kernel_size, self.stride, self.padding,
if
add_input
is
not
None
and
not
self
.
enable_int8_test_mode
:
# self.dilation, self.output_padding, out_spatial_shape)
out_tensor
=
out_tensor
.
replace_feature
(
_apply_act
(
out_tensor
.
features
+
add_input
.
features
,
self
.
act_type
,
self
.
act_alpha
,
self
.
act_beta
))
out_tensor
.
int8_scale
=
output_scale
return
out_tensor
return
out_tensor
def
_check_subm_reuse_valid
(
self
,
inp
:
SparseConvTensor
,
def
_check_subm_reuse_valid
(
self
,
inp
:
SparseConvTensor
,
...
...
spconv/pytorch/core.py
View file @
aa26c99e
...
@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
...
@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape
,
is_subm
:
bool
,
algo
:
ConvAlgo
,
out_spatial_shape
,
is_subm
:
bool
,
algo
:
ConvAlgo
,
ksize
:
List
[
int
],
stride
:
List
[
int
],
dilation
:
List
[
int
],
padding
:
List
[
int
],
ksize
:
List
[
int
],
stride
:
List
[
int
],
dilation
:
List
[
int
],
padding
:
List
[
int
],
in_voxel_num
:
Optional
[
Any
]
=
None
,
in_voxel_num
:
Optional
[
Any
]
=
None
,
out_voxel_num
:
Optional
[
Any
]
=
None
,
out_voxel_num
:
Optional
[
Any
]
=
None
):
mask_int_count
:
int
=
1
):
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
self
.
indices
=
indices
self
.
indices
=
indices
self
.
pair_fwd
=
pair_fwd
self
.
pair_fwd
=
pair_fwd
...
@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
...
@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion.
# in/out voxel_num is only used in tensorrt conversion.
self
.
in_voxel_num
=
in_voxel_num
self
.
in_voxel_num
=
in_voxel_num
self
.
out_voxel_num
=
out_voxel_num
self
.
out_voxel_num
=
out_voxel_num
self
.
mask_int_count
=
mask_int_count
def
scatter_nd
(
indices
,
updates
,
shape
):
def
scatter_nd
(
indices
,
updates
,
shape
):
...
@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
self
.
force_algo
=
force_algo
# for simple int8 torch inference
self
.
int8_scale
:
Optional
[
float
]
=
None
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...
...
spconv/pytorch/functional.py
View file @
aa26c99e
...
@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
...
@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
:
List
[
torch
.
Tensor
],
mask_argsort_bwd_splits
:
List
[
torch
.
Tensor
],
num_activate_out
:
int
,
num_activate_out
:
int
,
masks
:
List
[
np
.
ndarray
],
masks
:
List
[
np
.
ndarray
],
mask_int_count
:
int
,
is_train
:
bool
,
is_train
:
bool
,
is_subm
:
bool
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
...
@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
...
@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
try
:
try
:
out
,
mask_out
,
mask_width
=
ops
.
implicit_gemm
(
out
,
mask_out
,
mask_width
=
ops
.
implicit_gemm
(
features
,
filters
,
pair_fwd
,
pair_mask_fwd_splits
,
features
,
filters
,
pair_fwd
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
mask_int_count
,
is_train
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
is_train
,
is_subm
,
timer
,
fp32_accum
,
bias
,
act_alpha
,
act_beta
,
is_subm
,
timer
,
fp32_accum
,
bias
,
act_alpha
,
act_beta
,
act_type
)
act_type
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
...
@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
ctx
.
masks
=
masks
ctx
.
masks
=
masks
ctx
.
is_subm
=
is_subm
ctx
.
is_subm
=
is_subm
ctx
.
fp32_accum
=
fp32_accum
ctx
.
fp32_accum
=
fp32_accum
ctx
.
mask_int_count
=
mask_int_count
return
out
return
out
@
staticmethod
@
staticmethod
...
@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
...
@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
is_subm
=
ctx
.
is_subm
is_subm
=
ctx
.
is_subm
timer
=
ctx
.
timer
timer
=
ctx
.
timer
fp32_accum
=
ctx
.
fp32_accum
fp32_accum
=
ctx
.
fp32_accum
mask_int_count
=
ctx
.
mask_int_count
try
:
try
:
input_bp
,
filters_bp
=
ops
.
implicit_gemm_backward
(
input_bp
,
filters_bp
=
ops
.
implicit_gemm_backward
(
...
@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
...
@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
,
mask_argsort_bwd_splits
,
mask_output_fwd
=
mask_out
,
mask_output_fwd
=
mask_out
,
masks
=
masks
,
masks
=
masks
,
mask_int_count
=
mask_int_count
,
mask_width
=
mask_width
,
mask_width
=
mask_width
,
is_subm
=
is_subm
,
is_subm
=
is_subm
,
timer
=
timer
,
timer
=
timer
,
...
@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
...
@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
,
masks
))
mask_argsort_bwd_splits
,
masks
))
raise
e
raise
e
None_9
=
[
None
]
*
1
7
None_9
=
[
None
]
*
1
6
return
(
input_bp
,
filters_bp
,
*
None_9
)
return
(
input_bp
,
filters_bp
,
*
None_9
)
...
...
spconv/pytorch/ops.py
View file @
aa26c99e
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment