Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
aa26c99e
Commit
aa26c99e
authored
Dec 29, 2022
by
yan.yan
Browse files
working on quantization
parent
ee8c9465
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
745 additions
and
465 deletions
+745
-465
example/fuse_bn_act.py
example/fuse_bn_act.py
+11
-1
example/libspconv/main.cu
example/libspconv/main.cu
+2
-2
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+235
-0
example/mnist/mnist_sparse.py
example/mnist/mnist_sparse.py
+0
-0
pyproject.toml
pyproject.toml
+1
-1
spconv/algo.py
spconv/algo.py
+53
-21
spconv/algocore.py
spconv/algocore.py
+7
-2
spconv/build.py
spconv/build.py
+2
-2
spconv/core.py
spconv/core.py
+60
-43
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+8
-51
spconv/core_cc/csrc/sparse/convops/convops.pyi
spconv/core_cc/csrc/sparse/convops/convops.pyi
+11
-6
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+6
-4
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+44
-108
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+86
-28
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+12
-11
spconv/csrc/sparse/pointops.py
spconv/csrc/sparse/pointops.py
+3
-3
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+128
-33
spconv/pytorch/core.py
spconv/pytorch/core.py
+3
-3
spconv/pytorch/functional.py
spconv/pytorch/functional.py
+2
-6
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+71
-140
No files found.
example/fuse_bn_act.py
View file @
aa26c99e
...
...
@@ -207,7 +207,7 @@ class Net(nn.Module):
pool_algo
=
algo
# pool_algo = ConvAlgo.Native
self
.
net
=
spconv
.
SparseSequential
(
spconv
.
SubMConv3d
(
3
,
64
,
3
,
bias
=
False
,
indice_key
=
"c0"
,
spconv
.
SubMConv3d
(
16
,
64
,
3
,
bias
=
False
,
indice_key
=
"c0"
,
algo
=
algo
),
nn
.
BatchNorm1d
(
64
),
nn
.
ReLU
(),
...
...
@@ -373,6 +373,11 @@ class Net(nn.Module):
x
=
spconv
.
SparseConvTensor
(
features
,
coors
,
self
.
shape
,
batch_size
,
voxel_num
=
vx_num
)
return
self
.
net
(
x
)
def
_set_enable_int8_test_inplace
(
simple_module
:
torch
.
fx
.
GraphModule
,
enable
:
bool
):
for
m
in
simple_module
.
modules
():
if
isinstance
(
m
,
SparseConvolution
):
if
m
.
in_channels
%
32
==
0
and
m
.
out_channels
%
32
==
0
:
m
.
enable_int8_test_mode
=
enable
class
MyTracer
(
torch
.
fx
.
Tracer
):
...
...
@@ -387,6 +392,7 @@ def main():
torch
.
backends
.
cudnn
.
allow_tf32
=
False
with
open
(
Path
(
__file__
).
parent
.
parent
/
"test"
/
"data"
/
"test_spconv.pkl"
,
"rb"
)
as
f
:
(
voxels
,
coors
,
spatial_shape
)
=
pickle
.
load
(
f
)
voxels
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
voxels
.
shape
[
0
],
16
]).
astype
(
np
.
float32
)
np
.
random
.
seed
(
50051
)
device
=
torch
.
device
(
"cuda:0"
)
device_cpu
=
torch
.
device
(
"cpu:0"
)
...
...
@@ -408,6 +414,10 @@ def main():
out_fused
=
net_fused
(
voxels_th_cuda
,
coors_th_cuda
,
1
)
res
=
Fsp
.
sparse_add_hash_based
(
out_ref
,
out_fused
.
minus
())
print
(
torch
.
linalg
.
norm
(
res
.
features
))
_set_enable_int8_test_inplace
(
net_fused
,
True
)
qvoxels_cuda
=
voxels_th_cuda
.
to
(
torch
.
int8
)
out_int8
=
net_fused
(
qvoxels_cuda
,
coors_th_cuda
,
1
)
if
__name__
==
"__main__"
:
main
()
\ No newline at end of file
example/libspconv/main.cu
View file @
aa26c99e
...
...
@@ -426,7 +426,7 @@ int main(int argc, char **argv) {
{
SPCONV_ALLOC_OUT_FEATURES
,
out_features
}};
StaticAllocator
alloc2
(
tensor_dict
);
ConvTunerSimple
tuner
(
ConvMain
::
get_all_conv_algo_desp
());
auto
conv_r
e
s
=
ConvGemmOps
::
implicit_gemm
(
auto
conv_r
un_statu
s
=
ConvGemmOps
::
implicit_gemm
(
alloc2
,
tuner
,
input_features_real
,
weights
,
pair_fwd_real
,
pair_mask_splits
,
mask_argsort_splits
,
num_act_out_real
,
mask_tensor
,
arch
,
false
,
is_subm
,
...
...
@@ -435,7 +435,7 @@ int main(int argc, char **argv) {
1.0
/*bias alpha, only used for leaky relu*/
,
0.0
/*unused for now*/
,
tv
::
gemm
::
Activation
::
kReLU
);
tv
::
ssprint
(
"selected conv algo"
,
std
::
get
<
1
>
(
conv_r
e
s
).
algo_desp
.
__repr__
());
std
::
get
<
1
>
(
conv_r
un_statu
s
).
algo_desp
.
__repr__
());
// FINISH!!!
}
// calc maximum number of output points.
...
...
example/mnist/mnist_qat.py
0 → 100644
View file @
aa26c99e
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
argparse
import
torch
import
spconv.pytorch
as
spconv
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
from
torchvision
import
datasets
,
transforms
from
torch.optim.lr_scheduler
import
StepLR
import
contextlib
import
torch.cuda.amp
@
contextlib
.
contextmanager
def
identity_ctx
():
yield
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
nn
.
BatchNorm1d
(
1
),
spconv
.
SubMConv2d
(
1
,
32
,
3
,
1
),
nn
.
ReLU
(),
spconv
.
SubMConv2d
(
32
,
64
,
3
,
1
),
nn
.
ReLU
(),
spconv
.
SparseConv2d
(
64
,
64
,
2
,
2
),
spconv
.
ToDense
(),
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
x
.
reshape
(
-
1
,
28
,
28
,
1
))
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
dropout1
(
x
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
model
.
train
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
amp_ctx
=
contextlib
.
nullcontext
()
if
args
.
fp16
:
amp_ctx
=
torch
.
cuda
.
amp
.
autocast
()
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
with
amp_ctx
:
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
scale
=
1.0
if
args
.
fp16
:
assert
loss
.
dtype
is
torch
.
float32
scaler
.
scale
(
loss
).
backward
()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
# scaler.unscale_(optim)
# Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
# You may use the same value for max_norm here as you would without gradient scaling.
# torch.nn.utils.clip_grad_norm_(models[0].net.parameters(), max_norm=0.1)
scaler
.
step
(
optimizer
)
# Updates the scale for next iteration.
scaler
.
update
()
scale
=
scaler
.
get_scale
()
else
:
loss
.
backward
()
optimizer
.
step
()
if
batch_idx
%
args
.
log_interval
==
0
:
print
(
'Train Epoch: {} [{}/{} ({:.0f}%)]
\t
Loss: {:.6f}'
.
format
(
epoch
,
batch_idx
*
len
(
data
),
len
(
train_loader
.
dataset
),
100.
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
def
test
(
args
,
model
,
device
,
test_loader
):
model
.
eval
()
test_loss
=
0
correct
=
0
amp_ctx
=
contextlib
.
nullcontext
()
if
args
.
fp16
:
amp_ctx
=
torch
.
cuda
.
amp
.
autocast
()
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
with
amp_ctx
:
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
# get the index of the max log-probability
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'
\n
Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
\n
'
.
format
(
test_loss
,
correct
,
len
(
test_loader
.
dataset
),
100.
*
correct
/
len
(
test_loader
.
dataset
)))
def
main
():
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
64
,
metavar
=
'N'
,
help
=
'input batch size for training (default: 64)'
)
parser
.
add_argument
(
'--test-batch-size'
,
type
=
int
,
default
=
1000
,
metavar
=
'N'
,
help
=
'input batch size for testing (default: 1000)'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
14
,
metavar
=
'N'
,
help
=
'number of epochs to train (default: 14)'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
1.0
,
metavar
=
'LR'
,
help
=
'learning rate (default: 1.0)'
)
parser
.
add_argument
(
'--gamma'
,
type
=
float
,
default
=
0.7
,
metavar
=
'M'
,
help
=
'Learning rate step gamma (default: 0.7)'
)
parser
.
add_argument
(
'--no-cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1
,
metavar
=
'S'
,
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--log-interval'
,
type
=
int
,
default
=
10
,
metavar
=
'N'
,
help
=
'how many batches to wait before logging training status'
)
parser
.
add_argument
(
'--save-model'
,
action
=
'store_true'
,
default
=
False
,
help
=
'For Saving the current Model'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
default
=
False
,
help
=
'For mixed precision training'
)
args
=
parser
.
parse_args
()
use_cuda
=
not
args
.
no_cuda
and
torch
.
cuda
.
is_available
()
torch
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'../data'
,
train
=
True
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
**
kwargs
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
'../data'
,
train
=
False
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size
=
args
.
test_batch_size
,
shuffle
=
True
,
**
kwargs
)
model
=
Net
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
test
(
args
,
model
,
device
,
test_loader
)
scheduler
.
step
()
if
args
.
save_model
:
torch
.
save
(
model
.
state_dict
(),
"mnist_cnn.pt"
)
if
__name__
==
'__main__'
:
main
()
example/mnist_sparse.py
→
example/mnist
/mnist
_sparse.py
View file @
aa26c99e
File moved
pyproject.toml
View file @
aa26c99e
[build-system]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.7"
]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu1
18
-0.3.
4
-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu1
20
-0.3.
7
-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend
=
"setuptools.build_meta"
spconv/algo.py
View file @
aa26c99e
...
...
@@ -616,6 +616,7 @@ class SimpleConv:
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
all_desps
=
all_desps
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
...
...
@@ -648,13 +649,13 @@ class SimpleConv:
tile_ms_list
,
tile_ns_list
,
tile_ks_list
,
tile_shape_to_algos
)
self
.
kc_forward_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
int
,
bool
],
BestConvAlgoByProfile
]
=
{}
# for forward
self
.
kc_dgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
,
bool
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
,
bool
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
...
...
@@ -679,11 +680,12 @@ class SimpleConv:
op_type
:
ConvOpType
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
,
use_tf32
:
bool
=
True
):
use_tf32
:
bool
=
True
,
bias
:
tv
.
Tensor
=
tv
.
Tensor
(),
scale
:
tv
.
Tensor
=
tv
.
Tensor
()):
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
#
and out.dtype == tv.float16
use_f32_as_accum
=
False
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
# for 3d conv, if reduce axis is too large, may cause nan during
...
...
@@ -703,6 +705,10 @@ class SimpleConv:
layout_w
.
interleave
,
layout_o
.
interleave
,
inp
.
dtype
,
weight
.
dtype
,
out
.
dtype
,
op_type
.
value
)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
# for d in self.all_desps:
# print(d)
# print(len(desps))
# breakpoint()
if
desps
is
None
or
len
(
desps
)
==
0
:
return
finally_algos
for
desp
in
desps
:
...
...
@@ -726,11 +732,21 @@ class SimpleConv:
ldw
=
weight
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
mask_width_valid
=
True
if
desp
.
op_type
.
value
==
ConvOpType
.
kBackwardWeight
.
value
:
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
require_dynamic_mask
=
kv
>
32
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
not
bias
.
empty
()
and
not
scale
.
empty
():
# int8 inference, bias/scale dtype must equal to compute dtype in gemm
assert
bias
.
dtype
==
scale
.
dtype
if
desp
.
dcomp
!=
bias
.
dtype
:
continue
if
not
desp
.
is_int8_inference
:
continue
else
:
if
desp
.
is_int8_inference
:
continue
if
desp
.
is_nvrtc
:
if
not
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
continue
...
...
@@ -747,6 +763,12 @@ class SimpleConv:
continue
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
if
require_dynamic_mask
:
if
not
desp
.
dynamic_mask
:
continue
else
:
if
desp
.
dynamic_mask
:
continue
finally_algos
.
append
(
desp
)
return
finally_algos
...
...
@@ -758,11 +780,12 @@ class SimpleConv:
k
:
int
,
c
:
int
,
arch
:
Tuple
[
int
,
int
],
mask_width
:
int
=
-
1
):
mask_width
:
int
=
-
1
,
need_dynamic_mask
:
bool
=
False
):
if
not
op_type
==
ConvOpType
.
kBackwardWeight
:
# fwd and dgrad don't need
mask_width
=
-
1
key
=
(
i_dtype
,
w_dtype
,
o_dtype
,
k
,
c
,
arch
[
0
],
arch
[
1
],
mask_width
)
key
=
(
i_dtype
,
w_dtype
,
o_dtype
,
k
,
c
,
arch
[
0
],
arch
[
1
],
mask_width
,
need_dynamic_mask
)
if
op_type
==
ConvOpType
.
kForward
:
return
self
.
kc_forward_cache
.
get
(
key
,
None
)
elif
op_type
==
ConvOpType
.
kBackwardInput
:
...
...
@@ -795,8 +818,9 @@ class SimpleConv:
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
verbose
=
True
,
custom_names
=
custom_names
,
verbose_path
=
"/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8"
)
mod
.
load
()
return
mod
,
kernel
...
...
@@ -824,7 +848,6 @@ class SimpleConv:
mask_argsort
:
tv
.
Tensor
,
indices
:
tv
.
Tensor
,
reverse_mask
:
bool
,
mask_int_count
:
int
=
1
,
mask_filter
:
int
=
0xffffffff
,
mask_width
:
int
=
-
1
,
mask_output
:
tv
.
Tensor
=
tv
.
Tensor
(),
...
...
@@ -832,17 +855,20 @@ class SimpleConv:
beta
:
float
=
0.0
,
stream
:
int
=
0
,
fp32_accum
:
Optional
[
bool
]
=
None
,
use_tf32
:
bool
=
True
):
use_tf32
:
bool
=
True
,
bias
:
tv
.
Tensor
=
tv
.
Tensor
(),
scale
:
tv
.
Tensor
=
tv
.
Tensor
()):
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
layout_o
,
arch
,
op_type
,
mask_width
,
fp32_accum
,
use_tf32
)
fp32_accum
,
use_tf32
,
bias
,
scale
)
inp
=
inp
.
clone
()
weight
=
weight
.
clone
()
output
=
output
.
clone
()
print
(
len
(
avail
),
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
bias
.
dtype
,
scale
.
dtype
,
bias
.
empty
(),
scale
.
empty
())
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
need_dynamic_mask
=
weight
.
dim
(
1
)
>
32
times
:
List
[
float
]
=
[]
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
group_by_algo
=
{}
...
...
@@ -865,8 +891,9 @@ class SimpleConv:
params
.
indices
=
indices
params
.
mask
=
mask
params
.
mask_output
=
mask_output
params
.
mask_int_count
=
mask_int_count
if
desp
.
is_int8_inference
:
params
.
bias
=
bias
params
.
scale
=
scale
# if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty()
if
op_type
==
ConvOpType
.
kBackwardInput
:
...
...
@@ -909,7 +936,7 @@ class SimpleConv:
# fwd and dgrad don't need
mask_width
=
-
1
key
=
(
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
channel_k
,
channel_c
,
arch
[
0
],
arch
[
1
],
mask_width
)
arch
[
0
],
arch
[
1
],
mask_width
,
need_dynamic_mask
)
with
self
.
lock
:
if
op_type
==
ConvOpType
.
kForward
:
self
.
kc_forward_cache
[
key
]
=
res
...
...
@@ -945,7 +972,9 @@ class SimpleConv:
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
mask_int_count
:
Union
[
int
,
None
]
=
None
):
scale
:
Optional
[
tv
.
Tensor
]
=
None
,
output_add
:
Optional
[
tv
.
Tensor
]
=
None
):
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
# GemmMainUnitTest.stream_synchronize(stream)
...
...
@@ -986,9 +1015,12 @@ class SimpleConv:
params
.
mask_filter
=
mask_filter
params
.
mask_output
=
mask_output
params
.
reverse_mask
=
reverse_mask
params
.
mask_int_count
=
mask_int_count
if
bias
is
not
None
:
params
.
bias
=
bias
if
output_add
is
not
None
and
algo_desp
.
is_int8_inference
:
params
.
output_add
=
output_add
if
scale
is
not
None
and
algo_desp
.
is_int8_inference
:
params
.
scale
=
scale
if
timer
.
enable
:
assert
timer
.
_timer
is
not
None
params
.
timer
=
timer
.
_timer
...
...
spconv/algocore.py
View file @
aa26c99e
...
...
@@ -36,8 +36,8 @@ from cumm.gemm.algospec import TensorOp
def
_assign_gemm_desp_props
(
desp
:
Union
[
ConvAlgoDesp
,
GemmAlgoDesp
],
p
:
Union
[
GemmAlgoParams
,
ConvAlgoParams
]):
desp
.
dtype_a
=
p
.
dtype_a
.
tv_dtype
desp
.
dtype_b
=
p
.
dtype_
a
.
tv_dtype
desp
.
dtype_c
=
p
.
dtype_
a
.
tv_dtype
desp
.
dtype_b
=
p
.
dtype_
b
.
tv_dtype
desp
.
dtype_c
=
p
.
dtype_
c
.
tv_dtype
desp
.
dacc
=
p
.
dtype_acc
.
tv_dtype
desp
.
dcomp
=
p
.
dtype_comp
.
tv_dtype
desp
.
trans_a
=
p
.
trans_a
...
...
@@ -87,6 +87,9 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp
.
element_per_access_a
=
ker
.
input_spec
.
input_iter_a
.
element_per_acc
desp
.
element_per_access_b
=
ker
.
input_spec
.
input_iter_b
.
element_per_acc
desp
.
element_per_access_c
=
ker
.
output_spec
.
out_iter
.
element_per_acc
desp
.
is_int8_inference
=
ker
.
int8_inference
desp
.
dynamic_mask
=
ker
.
dynamic_mask
desp
.
min_arch
=
ker
.
min_arch
()
return
desp
...
...
@@ -141,4 +144,6 @@ def get_conv_param_from_desp(desp: ConvAlgoDesp):
desp
.
interleave_o
)
p
.
mask_sparse
=
desp
.
mask_sparse
p
.
increment_k_first
=
desp
.
increment_k_first
p
.
int8_inference
=
desp
.
is_int8_inference
p
.
dynamic_mask
=
desp
.
dynamic_mask
return
p
spconv/build.py
View file @
aa26c99e
...
...
@@ -39,12 +39,12 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from
spconv.csrc.sparse.inference
import
InferenceOps
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_AMPERE_PARAMS
all_shuffle
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_shuffle
))
#
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu
=
GemmMainUnitTest
(
all_shuffle
)
cu
.
namespace
=
"cumm.gemm.main"
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
#
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
.
namespace
=
"cumm.conv.main"
gemmtuner
=
GemmTunerSimple
(
cu
)
...
...
spconv/core.py
View file @
aa26c99e
...
...
@@ -619,14 +619,11 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -635,13 +632,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -650,13 +648,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -665,13 +664,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -680,13 +680,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -695,13 +696,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -710,13 +712,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -725,13 +728,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -740,13 +744,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -755,13 +760,14 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
],
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -770,7 +776,8 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
]
IMPLGEMM_TURING_PARAMS
=
[
...
...
@@ -779,7 +786,7 @@ IMPLGEMM_TURING_PARAMS = [
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -788,13 +795,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -803,13 +811,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -818,13 +827,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -833,13 +843,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -848,13 +859,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -863,13 +875,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -878,13 +891,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -893,13 +907,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -908,13 +923,14 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
"s8,s8,s8,s32,
s32"
,
[
"s8,s8,s8,s32,
f32"
,
"s8,s8,s8,s32,f16"
]
,
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -923,7 +939,8 @@ IMPLGEMM_TURING_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
),
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
16
,
16
),
(
16
,
16
,
16
),
...
...
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
aa26c99e
...
...
@@ -144,7 +144,7 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
...
...
@@ -167,11 +167,10 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_conv_inds_stage2_mask_direct_table(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
...
...
@@ -194,11 +193,10 @@ class SpconvOps:
dilation:
transposed:
stream_int:
mask_int_count:
"""
...
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0
, mask_int_count: int = 1
) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
Args:
indices:
...
...
@@ -214,7 +212,6 @@ class SpconvOps:
indice_pair_mask:
backward:
stream_int:
mask_int_count:
"""
...
@staticmethod
...
...
@@ -383,65 +380,25 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator
_mask32
(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0
, mask_count: int = 1
) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
mask_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_
mask32_
v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0
, mask_count: int = 1
) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
alloc_func:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask128_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0) -> Tensor:
"""
Args:
data:
allocator:
indices:
stream:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
"""
...
@staticmethod
def sort_1d_by_key_allocator_mask_auto_v2(data: Tensor, alloc_param, indices: Tensor = Tensor(), stream: int = 0, mask_int_count: int = 1) -> Tensor:
"""
Args:
data:
alloc_param:
indices:
stream:
mask_int_count:
mask_count:
"""
...
@staticmethod
...
...
@@ -598,7 +555,7 @@ class SpconvOps:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor,
int,
int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
"""
Args:
allocator:
...
...
spconv/core_cc/csrc/sparse/convops/convops.pyi
View file @
aa26c99e
...
...
@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch:
"""
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]:
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True
, bias: Tensor = Tensor(), scale: Tensor = Tensor()
) -> List[ConvAlgoDesp]:
"""
Args:
inp:
...
...
@@ -38,6 +38,8 @@ class ConvTunerSimple:
auto_fp32_accum:
fp32_accum:
use_tf32:
bias:
scale:
"""
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...
...
@@ -48,7 +50,7 @@ class ConvTunerSimple:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0,
mask_int_count: int = 1,
auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True
, bias: Tensor = Tensor(), scale: Tensor = Tensor()
) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
...
...
@@ -72,14 +74,15 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
auto_fp32_accum:
fp32_accum:
num_run:
use_tf32:
bias:
scale:
"""
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1
, need_dynamic_mask: bool = False
) -> Tuple[Any, bool]:
"""
Args:
op_type:
...
...
@@ -90,9 +93,10 @@ class ConvTunerSimple:
c:
arch:
mask_width:
need_dynamic_mask:
"""
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0,
mask_int_count: int = 1,
workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_
, scale: Tensor = Tensor(), output_add: Tensor = Tensor()
) -> None:
"""
Args:
profile_res:
...
...
@@ -110,7 +114,6 @@ class ConvTunerSimple:
alpha:
beta:
stream_int:
mask_int_count:
workspace:
verbose:
timer:
...
...
@@ -119,6 +122,8 @@ class ConvTunerSimple:
act_alpha:
act_beta:
act_type:
scale:
output_add:
"""
...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
aa26c99e
...
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor,
mask_int_count: int,
arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True
, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0
) -> Tuple[int, Any]:
"""
Args:
allocator:
...
...
@@ -75,7 +75,6 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
num_activate_out:
masks:
mask_int_count:
arch:
is_train:
is_subm:
...
...
@@ -88,10 +87,14 @@ class ConvGemmOps:
act_beta:
act_type:
use_tf32:
output_scale:
scale:
output_add:
output_add_scale:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor,
mask_int_count: int,
arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
"""
Args:
allocator:
...
...
@@ -107,7 +110,6 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
mask_int_count:
arch:
mask_width:
is_subm:
...
...
spconv/csrc/sparse/all.py
View file @
aa26c99e
...
...
@@ -30,7 +30,7 @@ from .alloc import ExternalAllocator, ThrustAllocator
from
spconv.constants
import
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
,
AllocKeys
import
re
import
os
from
cumm.gemm.codeops
import
dispatch
class
CustomThrustLib
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
...
...
@@ -462,7 +462,6 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...
...
@@ -489,7 +488,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int
, mask_int_count
);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
...
@@ -513,7 +512,6 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(output_dims.size() == ndim && input_dims.size() == ndim &&
...
...
@@ -540,7 +538,7 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int
, mask_int_count
);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
...
@@ -561,7 +559,6 @@ class SpconvOps(pccm.Class):
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"backward"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int = 0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
int ndim = indices.dim(1) - 1;
TV_ASSERT_RT_ERR(input_dims.size() == ndim &&
...
...
@@ -582,7 +579,7 @@ class SpconvOps(pccm.Class):
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward,
stream_int
, mask_int_count
);
stream_int);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
...
@@ -909,7 +906,7 @@ class SpconvOps(pccm.Class):
"""
)
return
code
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
,
int_count
:
int
=
1
):
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
...
...
@@ -924,18 +921,7 @@ class SpconvOps(pccm.Class):
"tv::Tensor()"
,
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
code_after_include
=
f
"""
template <typename T> struct SmallOrEqualTo {{
TV_HOST_DEVICE_INLINE T operator()(const T &x, const T &y) const {{
return x < y;
}}
}};
template <typename T> __global__ void mask_input(T* inp, T mask, int size){{
for (int i : tv::KernelLoopX<int>(size)){{
inp[i] &= mask;
}}
}}
"""
code
.
arg
(
"mask_count"
,
"int"
,
"1"
,
pyanno
=
"int"
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
if
not
use_allocator
:
...
...
@@ -945,20 +931,29 @@ class SpconvOps(pccm.Class):
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
indices = tv::empty({{data.dim(0)
/
{
int_count
}
}}, tv::int32, 0);
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
// auto timer = tv::CUDATimer();
tv::dispatch<int32_t, uint32_t, int64_t, uint64_t>(data.dtype(), [&](auto I){{
using T_ = TV_DECLTYPE(I);
using T =
{
"T_"
if
int_count
==
1
else
f
"thrust::tuple<
{
', '
.
join
([
'T_'
]
*
int_count
)
}
>
"
}
;
"""
)
# nested tv::dispatch may cause compiler bug in msvc.
for
dtype
in
dispatch
(
code
,
[
dtypes
.
int32
,
dtypes
.
int64
,
dtypes
.
uint32
,
dtypes
.
uint64
],
"data.dtype()"
):
code
.
raw
(
f
"""
using T_ =
{
dtype
}
;
tv::dispatch_int<1, 2, 3, 4>(mask_count, [&](auto IV){{
constexpr int I = TV_DECLTYPE(IV)::value;
// we can't use thrust::tuple in mp_repeat_c directly because
// thrust tuple actually has fixed size template arguments.
using T = tv::mp_rename<tv::mp_repeat_c<tv::mp_list<T_>, I>, thrust::tuple>;
thrust::device_ptr<T> ptr_tr(reinterpret_cast<T*>(data.data_ptr<T_>()));
thrust::device_ptr<int32_t> ptr_k(indices.data_ptr<int32_t>());
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
auto ctx2 = thrust::cuda::par(allocator).on(stream_cu);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0)
/
{
int_count
}
, ptr_k);
thrust::sort_by_key(ctx2, ptr_tr, ptr_tr + data.dim(0), ptr_k);
}});
"""
)
code
.
raw
(
f
"""
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
"""
)
...
...
@@ -967,71 +962,12 @@ class SpconvOps(pccm.Class):
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask32
(
self
):
# for python
def
sort_1d_by_key_allocator
(
self
):
return
self
.
sort_1d_by_key_allocator_template
(
False
)
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask32_v2
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask128
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
False
,
4
)
@
pccm
.
pybind
.
mark
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_mask128_v2
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
True
,
4
)
def
sort_1d_by_key_allocator_mask_auto_template
(
self
,
use_allocator
:
bool
):
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
code
.
arg
(
"data"
,
"tv::Tensor"
)
if
not
use_allocator
:
code
.
arg
(
"alloc_param"
,
"std::function<std::uintptr_t(std::size_t)>"
)
else
:
code
.
arg
(
"alloc_param"
,
"ThrustAllocator&"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
,
"tv::Tensor()"
,
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
switch (mask_int_count){{
case 1:
return sort_1d_by_key_allocator_mask32
{
"_v2"
if
use_allocator
else
""
}
(data, alloc_param, indices, stream);
case 4:
return sort_1d_by_key_allocator_mask128
{
"_v2"
if
use_allocator
else
""
}
(data, alloc_param, indices, stream);
default:
TV_ASSERT_RT_ERR(false, "Not implement for other mask_int_count");
return tv::Tensor();
}}
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
sort_1d_by_key_allocator_mask_auto
(
self
):
return
self
.
sort_1d_by_key_allocator_mask_auto_template
(
False
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
sort_1d_by_key_allocator_mask_auto_v2
(
self
):
return
self
.
sort_1d_by_key_allocator_mask_auto_template
(
True
)
@
_STATIC_FUNCTION
def
sort_1d_by_key_allocator_v2
(
self
):
# for cpp only
return
self
.
sort_1d_by_key_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
...
...
@@ -1622,7 +1558,7 @@ class SpconvOps(pccm.Class):
code
.
raw
(
f
"""
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
hash_size =
tv::align_up(
int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory)
, 2)
;
}}
size_t res = 0;
if (subm){{
...
...
@@ -1655,7 +1591,7 @@ class SpconvOps(pccm.Class):
max_act_out_in_theory, subm, use_int64_hash_k, direct_table);
int hash_size = 2 * num_act_out_bound;
if (direct_table){{
hash_size = int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory);
hash_size =
tv::align_up(
int(
{
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
}
* max_act_out_in_theory)
, 2)
;
}}
if (use_int64_hash_k){{
auto ten = tv::from_blob(workspace, {{int64_t(hash_size)}}, tv::int64, 0);
...
...
@@ -1720,10 +1656,10 @@ class SpconvOps(pccm.Class):
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
int mask_int_count =
(kv + 31) /
32;
if (mask_int_count > 1 && mask_int_count < 4)
mask_int_count = 4;
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
int mask_int_count =
tv::div_up(kv,
32
)
;
//
if (mask_int_count > 1 && mask_int_count < 4)
//
mask_int_count = 4;
//
TV_ASSERT_RT_ERR(mask_int_count == 1 || mask_int_count == 4, "Not Implement too large kernel");
// TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int> out_shape;
if (!subm){{
...
...
@@ -1845,14 +1781,14 @@ class SpconvOps(pccm.Class):
pair_mask = preallocated.at(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
);
}}else{{
pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_in
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_in
,
mask_int_count}}, tv::uint32, 0, stream_int);
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int
, mask_int_count
);
batch_size, input_dims, ksize, dilation, pair_mask, is_train, stream_int);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
}}
"""
)
with
code
.
else_
():
...
...
@@ -1958,11 +1894,11 @@ Your Conv Params: )" << "\\n";
pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
{{kv, num_act_out}}, -1, indices.dtype(), indices.device(), stream_int);
pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_out
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, num_act_out
,
mask_int_count}}, tv::uint32, 0, stream_int);
pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
{{mask_split_count, indices.dim(0)
*
mask_int_count}}, tv::uint32, 0, stream_int);
{{mask_split_count, indices.dim(0)
,
mask_int_count}}, tv::uint32, 0, stream_int);
}}
}}
if (!direct_table){{
...
...
@@ -1994,13 +1930,13 @@ Your Conv Params: )" << "\\n";
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int
, mask_int_count
);
transposed, stream_int);
}}else{{
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int
, mask_int_count
);
transposed, stream_int);
}}
}}
"""
)
...
...
@@ -2030,21 +1966,21 @@ Your Conv Params: )" << "\\n";
}}
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
}}else{{
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
sort_1d_by_key_allocator_
mask_auto_
v2(pair_mask_bwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
}}
}}
}}
"""
)
code
.
raw
(
f
"""
return std::make_tuple(mask_tensor, num_act_out
, mask_int_count
);
return std::make_tuple(mask_tensor, num_act_out);
"""
)
return
code
.
ret
(
"std::tuple<tv::Tensor,
int,
int>"
)
return
code
.
ret
(
"std::tuple<tv::Tensor, int>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
...
...
spconv/csrc/sparse/convops.py
View file @
aa26c99e
...
...
@@ -936,7 +936,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
"int, int, int, int, int>"
))
self
.
add_typedef
(
"algo_cache_key_t"
,
"std::tuple<int, int, int, int, "
"int, int, int, int>"
)
"int, int, int, int
, bool
>"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::ConvAlgoDesp>"
)
self
.
add_member
(
...
...
@@ -1009,7 +1009,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"auto_fp32_accum"
,
"bool"
)
code
.
arg
(
"fp32_accum"
,
"bool"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"bias"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
raw
(
f
"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
...
...
@@ -1077,7 +1080,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
bool require_dynamic_mask = kv > 32;
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!bias.empty() && !scale.empty()){{
TV_ASSERT_RT_ERR(bias.dtype() == scale.dtype(), "bias/scale dtype must equal to compute dtype in gemm");
if (desp.dcomp != bias.dtype()){{
continue;
}}
if (!desp.is_int8_inference){{
continue;
}}
}}else{{
if (desp.is_int8_inference){{
continue;
}}
}}
auto desp2 = desp;
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
...
...
@@ -1093,6 +1111,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
}}
}}
if (require_dynamic_mask){{
if (!desp.dynamic_mask){{
continue;
}}
}}else{{
if (desp.dynamic_mask){{
continue;
}}
}}
finally_algos.push_back(desp2);
}}
}}
...
...
@@ -1138,11 +1165,14 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
code
.
arg
(
"num_run"
,
"int"
,
"5"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"bias"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
...
...
@@ -1157,12 +1187,15 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum, use_tf32);
auto_fp32_accum, fp32_accum, use_tf32,
bias, scale);
inp = inp.clone();
weight = weight.clone();
bool need_dynamic_mask = weight.dim(1) > 32;
output = output.clone();
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
weight = weight.view(channel_k, -1, channel_c);
std::vector<ConvTuneResult> all_profile_res;
std::unordered_set<int> splitk_tests;
...
...
@@ -1187,7 +1220,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.indices = indices;
params.mask = mask;
params.mask_output = mask_output;
params.mask_int_count = mask_int_count;
if (desp.is_int8_inference){{
params.bias = bias;
params.scale = scale;
}}
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
...
...
@@ -1246,7 +1282,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width);
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width
, need_dynamic_mask
);
{{
std::lock_guard<std::mutex> guard(mutex_);
...
...
@@ -1279,6 +1315,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"k, c"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
code
.
arg
(
"need_dynamic_mask"
,
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, bool>"
)
...
...
@@ -1290,7 +1328,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
}}
algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width);
std::get<0>(arch), std::get<1>(arch), mask_width
, need_dynamic_mask
);
ConvTuneResult res;
bool exists = false;
{{
...
...
@@ -1338,7 +1376,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
arg
(
"workspace"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"verbose"
,
f
"bool"
,
"false"
)
...
...
@@ -1350,7 +1387,10 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code
.
arg
(
"act_alpha"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
...
...
@@ -1376,6 +1416,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.output = output;
params.verbose = verbose;
params.bias = bias;
params.scale = scale;
params.split_k_slices = split_k_slices;
params.alpha = alpha;
...
...
@@ -1383,7 +1424,9 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.act_alpha = act_alpha;
params.act_beta = act_beta;
params.act_type = act_type;
if (!output_add.empty() && desp.is_int8_inference){{
params.output_add = output_add;
}}
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
...
...
@@ -1393,7 +1436,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
params.mask_width = mask_width;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
params.mask_int_count = mask_int_count;
if (timer.enable()){{
params.timer = timer;
...
...
@@ -2039,7 +2081,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"mask_int_count"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
...
...
@@ -2055,7 +2096,13 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"act_beta"
,
f
"float"
,
"0.0"
)
code
.
arg
(
"act_type"
,
f
"tv::gemm::Activation"
,
"tv::gemm::Activation::kNone"
,
"cumm.tensorview.gemm.Activation = Activation.None_"
)
code
.
arg
(
"use_tf32"
,
"bool"
,
"true"
)
code
.
arg
(
"output_scale"
,
"float"
,
"1.0"
)
code
.
arg
(
"scale"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"output_add_scale"
,
"float"
,
"1.0"
)
code
.
arg
(
"output_dtype"
,
"int"
,
"-1"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
...
...
@@ -2072,13 +2119,18 @@ class ConvGemmOps(pccm.ParameterizedClass):
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
int mask_int_count = tv::div_up(kv, 32);
tv::Tensor out_features;
if (output_dtype < 0){{
output_dtype = int(features.dtype());
}}
if (is_subm){{
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}},
features.
dtype
(
), features.device(), stream_int);
{{num_activate_out, out_channel}},
tv::DType(output_
dtype), features.device(), stream_int);
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}},
features.
dtype
(
), features.device(), stream_int);
{{num_activate_out, out_channel}},
tv::DType(output_
dtype), features.device(), stream_int);
}}
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
...
...
@@ -2113,20 +2165,24 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count, // mask_int_count is after stream_int
auto_fp32_accum,
fp32_accum,
5, // num_run
use_tf32);
use_tf32,
bias,
scale);
tune_res = std::get<0>(tune_res_time);
}}
float alpha = 1.0;
if (tune_res.algo_desp.is_int8_inference){{
alpha = output_scale;
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
mask_output_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskOutputFwd
)
}
,
{{num_split, tv::div_up(num_activate_out, mask_width)
*
mask_int_count}},
{{num_split, tv::div_up(num_activate_out, mask_width)
,
mask_int_count}},
tv::uint32, features.device(), stream_int);
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
...
...
@@ -2139,9 +2195,15 @@ class ConvGemmOps(pccm.ParameterizedClass):
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
if (!bias.empty()){{
if (!bias.empty() && !tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = 1;
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
}}
if (j > 0){{
bias = tv::Tensor();
}}
...
...
@@ -2158,9 +2220,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0
, beta,
alpha
, beta,
stream_int,
mask_int_count, // mask_int_count is after stream_int
tv::Tensor(), // workspace
false, // verbose
timer,
...
...
@@ -2168,7 +2229,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
bias,
act_alpha,
act_beta,
act_type);
act_type,
scale,
output_add);
}}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
...
...
@@ -2193,7 +2256,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"mask_int_count"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
)
...
...
@@ -2286,7 +2348,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
...
...
@@ -2311,7 +2372,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
mask_int_count,
auto_fp32_accum,
fp32_accum,
5, // num_run
...
...
@@ -2354,7 +2414,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
-1, // mask_width
1.0, beta,
stream_int,
mask_int_count,
tv::Tensor(), // workspace
false, // verbose
timer);
...
...
@@ -2372,7 +2431,6 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_width,
1.0, 0.0,
stream_int,
mask_int_count,
workspace, // workspace
false, // verbose
timer);
...
...
spconv/csrc/sparse/indices.py
View file @
aa26c99e
...
...
@@ -829,7 +829,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
uint32_t filter_mask_in = (1u << ((RS - 1 - filter_offset) % 32));
uint32_t filter_mask_in_offset = (RS - 1 - filter_offset) / 32;
// uint32_t filter_mask_center = (1u << (RS / 2));
loc_iter.set_filter_offset(filter_offset);
int indices_pair_size_mul_RS = indices_pair_size * RS;
int filter_offset_mul_indices_pair_size = filter_offset * indices_pair_size;
...
...
@@ -1255,13 +1254,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO handle num input == 0
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
// indice_pairs_bwd: [kv, num_act_in] or empty
// indice_pairs_fwd: [kv, num_act_out]
auto ctx = tv::Context();
...
...
@@ -1504,7 +1503,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"is_train"
,
"bool"
,
"true"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"mask_int_count"
,
"int"
,
"1"
)
code
.
raw
(
f
"""
int num_act_in_real = indices.dim(0);
...
...
@@ -1523,6 +1521,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
padding[i] = (ksize[i] / 2) * dilation[i];
}}
int kv = ksize.op<tv::arrayops::prod>();
int mask_int_count = tv::div_up(kv, 32);
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [1 or 2, kv, num_act_in] if mask else [2, kv, num_act_in]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
...
...
@@ -1556,8 +1555,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
if (!indice_pair_mask.empty()){{
TV_ASSERT_RT_ERR(indice_pairs.ndim() == 3, "error");
TV_ASSERT_RT_ERR(indice_pairs.dim(0) == (is_train ? 2 : 1), "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() ==
2
, "error");
// indice_pair_mask: [mask_split_count, num_act_in]
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() ==
3
, "error");
// indice_pair_mask: [mask_split_count, num_act_in
, num_mask_per_point
]
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0].slice_first_axis(0, num_act_in_real);
auto mask_1 = indice_pair_mask[1].slice_first_axis(0, num_act_in_real);
...
...
@@ -1571,13 +1570,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
indices.dim(0), indice_pairs.dim(2), kv, is_train);
}}else{{
// indice_pair_mask: [1, num_act_in]
// indice_pair_mask: [1, num_act_in
, num_mask_per_point
]
tv::cuda::Launch lanucher_fill(num_act_in_real, custream);
if (mask_int_count == 1)
if (mask_int_count == 1)
{{
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indices.dim(0));
else
}}
else{{
lanucher_fill(init_subm_multiple_mask_int_kernel<uint32_t>,
indice_pair_mask.data_ptr<uint32_t>(), kv / 2, indices.dim(0), mask_int_count);
}}
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t,
{
loc_type
}
>, loc_iter, hash,
indices.data_ptr<const int>(), indice_pairs.data_ptr<int>(),
...
...
spconv/csrc/sparse/pointops.py
View file @
aa26c99e
...
...
@@ -465,14 +465,14 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
table_launcher(kernel::assign_table<table_t>, hash, indices.data_ptr<int>(),
count.data_ptr<int>(),
layout, voxels.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
launcher(kernel::generate_voxel<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
point_indice_data.data_ptr<const int64_t>(), voxels.data_ptr<
{
self
.
dtype
}
>(),
num_per_voxel.data_ptr<int>(), points_voxel_id.data_ptr<int64_t>(), points.dim(1), voxels.dim(1),
voxels.dim(0), vsize_tv, coors_range_tv,
grid_size_tv, grid_stride_tv, points.dim(0));
auto count_cpu = count.cpu();
int count_val = count_cpu.item<int32_t>();
count_val = count_val > voxels.dim(0) ? voxels.dim(0) : count_val;
auto voxel_launcher = tv::cuda::Launch(count_val, custream);
if (empty_mean){{
launcher(kernel::voxel_empty_fill_mean, voxels.data_ptr<
{
self
.
dtype
}
>(),
...
...
spconv/pytorch/conv.py
View file @
aa26c99e
...
...
@@ -37,10 +37,23 @@ from spconv.utils import nullcontext
from
torch.nn.init
import
calculate_gain
from
cumm
import
tensorview
as
tv
from
torch.nn
import
functional
as
F
FILTER_HWIO
=
False
_MAX_NUM_VOXELS_DURING_TRAINING
=
"max_num_voxels_during_training"
def
_apply_act
(
x
:
torch
.
Tensor
,
act_type
:
tv
.
gemm
.
Activation
,
act_alpha
:
float
,
act_beta
:
float
):
if
act_type
==
tv
.
gemm
.
Activation
.
None_
:
return
x
elif
act_type
==
tv
.
gemm
.
Activation
.
ReLU
:
return
F
.
relu
(
x
)
elif
act_type
==
tv
.
gemm
.
Activation
.
Sigmoid
:
return
F
.
sigmoid
(
x
)
elif
act_type
==
tv
.
gemm
.
Activation
.
LeakyReLU
:
return
F
.
leaky_relu
(
x
,
act_alpha
)
else
:
raise
NotImplementedError
class
SparseConvolution
(
SparseModule
):
__constants__
=
[
...
...
@@ -104,7 +117,7 @@ class SparseConvolution(SparseModule):
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
self
.
record_voxel_count
=
record_voxel_count
if
algo
is
None
:
if
kv
<=
32
and
not
CPU_ONLY_BUILD
:
if
kv
<=
128
and
not
CPU_ONLY_BUILD
:
if
kv
<
8
:
algo
=
ConvAlgo
.
MaskImplicitGemm
else
:
...
...
@@ -139,6 +152,19 @@ class SparseConvolution(SparseModule):
self
.
act_type
=
act_type
self
.
act_alpha
=
act_alpha
self
.
act_beta
=
act_beta
self
.
enable_int8_test_mode
:
bool
=
False
self
.
_int8_weight
=
torch
.
Tensor
()
# calculated by max(abs(weight)) for each channel
self
.
_int8_weight_scale
=
torch
.
Tensor
()
# calculated by scale self.bias with _int8_input_scale
self
.
_int8_bias
=
torch
.
Tensor
()
# int8 inference must set _int8_input_scale
self
.
_int8_input_scale
:
Optional
[
float
]
=
None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self
.
_int8_output_scale
:
Optional
[
float
]
=
None
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
self
.
reset_parameters
()
...
...
@@ -151,11 +177,19 @@ class SparseConvolution(SparseModule):
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
None
def
set_int8_test
(
self
,
enable
:
bool
,
input_scale
:
float
,
output_scale
:
Optional
[
float
]
=
None
,
weight_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
self
.
_int8_input_scale
=
input_scale
self
.
_int8_output_scale
=
output_scale
if
weight_scale
is
not
None
:
self
.
_int8_weight_scale
=
weight_scale
self
.
enable_int8_test_mode
=
enable
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
_MAX_NUM_VOXELS_DURING_TRAINING
not
in
state_dict
:
state_dict
[
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
]
=
torch
.
zeros
(
name
=
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
name
not
in
state_dict
:
state_dict
[
name
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
)
if
not
SAVED_WEIGHT_LAYOUT
:
return
...
...
@@ -255,7 +289,10 @@ class SparseConvolution(SparseModule):
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
):
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
input
,
self
.
weight
,
self
.
bias
,
add_input
)
def
_conv_forward
(
self
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
assert
isinstance
(
input
,
SparseConvTensor
)
assert
input
.
features
.
shape
[
1
]
==
self
.
in_channels
,
"channel size mismatch"
...
...
@@ -264,9 +301,34 @@ class SparseConvolution(SparseModule):
indices
=
input
.
indices
spatial_shape
=
input
.
spatial_shape
batch_size
=
input
.
batch_size
bias_for_training
=
self
.
bias
if
self
.
training
else
None
bias_for_infer
=
self
.
bias
if
not
self
.
training
else
None
bias_for_training
=
bias
if
self
.
training
else
None
bias_for_infer
=
bias
if
not
self
.
training
else
None
output_scale
=
None
output_add_scale
=
1.0
if
self
.
enable_int8_test_mode
:
assert
not
self
.
training
,
"must in eval mode"
assert
self
.
algo
==
ConvAlgo
.
MaskImplicitGemm
,
"int8 inference only support MaskImplicitGemm"
assert
bias_for_infer
is
not
None
,
"conv-bn-relu must be fused"
assert
self
.
_int8_input_scale
is
not
None
if
features
.
dtype
!=
torch
.
int8
:
# quantize
features
=
torch
.
clamp
(
torch
.
round
(
features
/
self
.
_int8_input_scale
),
-
128
,
127
).
to
(
torch
.
int8
)
output_scale
=
self
.
_int8_output_scale
int8_out_scale
=
output_scale
if
int8_out_scale
is
None
:
int8_out_scale
=
1
if
add_input
is
not
None
:
assert
add_input
.
int8_scale
is
not
None
,
"only support int8 add"
output_add_scale
=
add_input
.
int8_scale
if
self
.
_int8_weight
.
numel
()
==
0
:
with
torch
.
no_grad
():
assert
ALL_WEIGHT_IS_KRSC
weight_scales
=
torch
.
abs
(
weight
).
view
(
self
.
out_channels
,
-
1
).
max
(
1
)[
0
]
num_1s
=
[
1
]
*
(
self
.
ndim
+
1
)
self
.
_int8_weight
=
(
weight
/
weight_scales
.
view
(
self
.
out_channels
,
*
num_1s
)
*
127
).
to
(
torch
.
int8
)
if
self
.
_int8_weight_scale
.
numel
()
==
0
:
self
.
_int8_weight_scale
=
int8_out_scale
/
(
self
.
_int8_input_scale
*
weight_scales
)
self
.
_int8_bias
=
bias_for_infer
*
int8_out_scale
if
self
.
training
:
msg
=
"act don't support backward, only used in inference"
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
...
...
@@ -310,18 +372,19 @@ class SparseConvolution(SparseModule):
"out_channels"
:
self
.
out_channels
,
}
}
if
self
.
conv1x1
:
if
self
.
conv1x1
and
not
self
.
enable_int8_test_mode
:
# in int8 test mode, we don't implement conv1x1 via mm.
if
FILTER_HWIO
:
features
=
torch
.
mm
(
input
.
features
,
self
.
weight
.
view
(
self
.
out_channels
,
self
.
in_channels
).
T
)
weight
.
view
(
self
.
out_channels
,
self
.
in_channels
).
T
)
else
:
features
=
torch
.
mm
(
input
.
features
,
self
.
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
if
self
.
bias
is
not
None
:
features
+=
self
.
bias
if
bias
is
not
None
:
features
+=
bias
out_tensor
=
out_tensor
.
replace_feature
(
features
)
# padding may change spatial shape of conv 1x1.
out_tensor
.
spatial_shape
=
out_spatial_shape
...
...
@@ -413,7 +476,7 @@ class SparseConvolution(SparseModule):
if
self
.
subm
:
out_features
=
Fsp
.
indice_subm_conv
(
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pair_num
,
outids
.
shape
[
0
],
...
...
@@ -427,7 +490,7 @@ class SparseConvolution(SparseModule):
if
self
.
inverse
:
out_features
=
Fsp
.
indice_inverse_conv
(
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pair_num
,
outids
.
shape
[
0
],
...
...
@@ -440,7 +503,7 @@ class SparseConvolution(SparseModule):
else
:
out_features
=
Fsp
.
indice_conv
(
features
,
self
.
weight
,
weight
,
indice_pairs_calc
,
indice_pair_num
,
outids
.
shape
[
0
],
...
...
@@ -481,11 +544,13 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits
=
datas
.
mask_argsort_fwd_splits
mask_argsort_bwd_splits
=
datas
.
mask_argsort_bwd_splits
masks
=
datas
.
masks
mask_int_count
=
datas
.
mask_int_count
assert
self
.
subm
,
"only support reuse subm indices"
self
.
_check_subm_reuse_valid
(
input
,
spatial_shape
,
datas
)
else
:
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
with
input
.
_timer
.
namespace
(
"gen_pairs"
):
# we need to gen bwd indices for regular conv
# because it may be inversed.
...
...
@@ -514,7 +579,11 @@ class SparseConvolution(SparseModule):
print
(
msg
,
file
=
sys
.
stderr
)
spconv_save_debug_data
(
indices
)
raise
e
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
out_tensor
.
benchmark_record
[
self
.
name
][
"indice_gen_time"
].
append
(
interval
)
outids
=
res
[
0
]
num_inds_per_loc
=
res
[
1
]
pair_fwd
=
res
[
2
]
...
...
@@ -524,7 +593,6 @@ class SparseConvolution(SparseModule):
mask_argsort_fwd_splits
=
res
[
6
]
mask_argsort_bwd_splits
=
res
[
7
]
masks
=
res
[
8
]
mask_int_count
=
res
[
9
]
if
self
.
indice_key
is
not
None
:
indice_data
=
ImplicitGemmIndiceData
(
outids
,
...
...
@@ -543,8 +611,7 @@ class SparseConvolution(SparseModule):
ksize
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
mask_int_count
=
mask_int_count
)
dilation
=
self
.
dilation
)
msg
=
f
"your indice key
{
self
.
indice_key
}
already exists in this sparse tensor."
assert
self
.
indice_key
not
in
indice_dict
,
msg
indice_dict
[
self
.
indice_key
]
=
indice_data
...
...
@@ -552,16 +619,43 @@ class SparseConvolution(SparseModule):
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
num_activate_out
=
outids
.
shape
[
0
]
weight_cur
=
weight
bias_cur
=
bias_for_infer
if
self
.
enable_int8_test_mode
:
assert
features
.
dtype
==
torch
.
int8
,
"in int8 test mode, feature must be int8"
weight_cur
=
self
.
_int8_weight
bias_cur
=
self
.
_int8_bias
if
self
.
training
:
out_features
=
Fsp
.
implicit_gemm
(
features
,
self
.
weight
,
pair_fwd
,
pair_bwd
,
features
,
weight
_cur
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
num_activate_out
,
masks
,
mask_int_count
,
self
.
training
,
self
.
subm
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_
for_infe
r
,
bias_
cu
r
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
else
:
output_dtype
=
None
if
self
.
_int8_output_scale
is
None
:
output_dtype
=
weight
.
dtype
out_features
,
_
,
_
=
ops
.
implicit_gemm
(
features
,
weight_cur
,
pair_fwd
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
,
# TODO do we really need output scale to scale bias in kernel?
1.0
,
# output_scale
self
.
_int8_weight_scale
,
# scale
output_add
=
add_input
.
features
if
add_input
is
not
None
else
None
,
output_add_scale
=
output_add_scale
,
output_dtype
=
output_dtype
)
if
bias_for_training
is
not
None
:
out_features
+=
bias_for_training
if
input
.
benchmark
:
...
...
@@ -581,8 +675,9 @@ class SparseConvolution(SparseModule):
out_tensor
.
indices
=
outids
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
spatial_shape
=
out_spatial_shape
# print(outids.shape, spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, out_spatial_shape)
if
add_input
is
not
None
and
not
self
.
enable_int8_test_mode
:
out_tensor
=
out_tensor
.
replace_feature
(
_apply_act
(
out_tensor
.
features
+
add_input
.
features
,
self
.
act_type
,
self
.
act_alpha
,
self
.
act_beta
))
out_tensor
.
int8_scale
=
output_scale
return
out_tensor
...
...
spconv/pytorch/core.py
View file @
aa26c99e
...
...
@@ -89,8 +89,7 @@ class ImplicitGemmIndiceData(object):
out_spatial_shape
,
is_subm
:
bool
,
algo
:
ConvAlgo
,
ksize
:
List
[
int
],
stride
:
List
[
int
],
dilation
:
List
[
int
],
padding
:
List
[
int
],
in_voxel_num
:
Optional
[
Any
]
=
None
,
out_voxel_num
:
Optional
[
Any
]
=
None
,
mask_int_count
:
int
=
1
):
out_voxel_num
:
Optional
[
Any
]
=
None
):
self
.
out_indices
=
out_indices
self
.
indices
=
indices
self
.
pair_fwd
=
pair_fwd
...
...
@@ -111,7 +110,6 @@ class ImplicitGemmIndiceData(object):
# in/out voxel_num is only used in tensorrt conversion.
self
.
in_voxel_num
=
in_voxel_num
self
.
out_voxel_num
=
out_voxel_num
self
.
mask_int_count
=
mask_int_count
def
scatter_nd
(
indices
,
updates
,
shape
):
...
...
@@ -183,6 +181,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
# for simple int8 torch inference
self
.
int8_scale
:
Optional
[
float
]
=
None
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...
...
spconv/pytorch/functional.py
View file @
aa26c99e
...
...
@@ -198,7 +198,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
:
List
[
torch
.
Tensor
],
num_activate_out
:
int
,
masks
:
List
[
np
.
ndarray
],
mask_int_count
:
int
,
is_train
:
bool
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
...
...
@@ -210,7 +209,7 @@ class SparseImplicitGemmFunction(Function):
try
:
out
,
mask_out
,
mask_width
=
ops
.
implicit_gemm
(
features
,
filters
,
pair_fwd
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
mask_int_count
,
is_train
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
is_train
,
is_subm
,
timer
,
fp32_accum
,
bias
,
act_alpha
,
act_beta
,
act_type
)
except
Exception
as
e
:
...
...
@@ -236,7 +235,6 @@ class SparseImplicitGemmFunction(Function):
ctx
.
masks
=
masks
ctx
.
is_subm
=
is_subm
ctx
.
fp32_accum
=
fp32_accum
ctx
.
mask_int_count
=
mask_int_count
return
out
@
staticmethod
...
...
@@ -255,7 +253,6 @@ class SparseImplicitGemmFunction(Function):
is_subm
=
ctx
.
is_subm
timer
=
ctx
.
timer
fp32_accum
=
ctx
.
fp32_accum
mask_int_count
=
ctx
.
mask_int_count
try
:
input_bp
,
filters_bp
=
ops
.
implicit_gemm_backward
(
...
...
@@ -270,7 +267,6 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
,
mask_output_fwd
=
mask_out
,
masks
=
masks
,
mask_int_count
=
mask_int_count
,
mask_width
=
mask_width
,
is_subm
=
is_subm
,
timer
=
timer
,
...
...
@@ -286,7 +282,7 @@ class SparseImplicitGemmFunction(Function):
mask_argsort_bwd_splits
,
masks
))
raise
e
None_9
=
[
None
]
*
1
7
None_9
=
[
None
]
*
1
6
return
(
input_bp
,
filters_bp
,
*
None_9
)
...
...
spconv/pytorch/ops.py
View file @
aa26c99e
...
...
@@ -23,7 +23,7 @@ import spconv
from
spconv.core
import
AlgoHint
,
ConvAlgo
from
typing
import
Dict
,
List
,
Optional
,
Union
from
spconv.pytorch.core
import
ThrustSortAllocator
from
spconv.pytorch.cppcore
import
TorchAllocator
,
torch_tensor_to_tv
,
get_current_stream
,
get_arch
,
TorchSpconvMatmul
from
spconv.pytorch.cppcore
import
_TORCH_DTYPE_TO_TV
,
TorchAllocator
,
torch_tensor_to_tv
,
get_current_stream
,
get_arch
,
TorchSpconvMatmul
from
spconv.core_cc.csrc.sparse.all
import
SpconvOps
from
spconv.core_cc.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.constants
import
SPCONV_CPP_INDICE_PAIRS
,
SPCONV_CPP_INDICE_PAIRS_IGEMM
,
SPCONV_CPP_GEMM
,
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
,
SPCONV_ALLOW_TF32
...
...
@@ -31,7 +31,7 @@ import spconv.core_cc as _ext
from
spconv.core_cc.csrc.sparse.convops.spops
import
ConvGemmOps
from
spconv.core_cc.csrc.sparse.inference
import
InferenceOps
from
spconv.cppconstants
import
CPU_ONLY_BUILD
from
cumm.gemm.codeops
import
div_up
from
spconv.utils
import
nullcontext
if
not
CPU_ONLY_BUILD
:
...
...
@@ -365,7 +365,7 @@ def get_indice_pairs_implicit_gemm(
timer_cpp
=
tv
.
CUDAKernelTimer
(
False
)
if
timer
.
_timer
is
not
None
:
timer_cpp
=
timer
.
_timer
mask_tensor
,
num_act_out
,
mask_int_count
=
SpconvOps
.
get_indice_pairs_implicit_gemm
(
mask_tensor
,
num_act_out
=
SpconvOps
.
get_indice_pairs_implicit_gemm
(
thalloc
,
torch_tensor_to_tv
(
indices
),
batch_size
,
...
...
@@ -408,7 +408,7 @@ def get_indice_pairs_implicit_gemm(
assert
pair
.
shape
[
0
]
==
2
pair_bwd
=
pair
[
1
]
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair_bwd
,
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
,
mask_int_count
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
pair_bwd
=
thalloc
.
allocated
.
get
(
AllocKeys
.
PairBwd
,
torch
.
Tensor
())
pair_fwd
=
thalloc
.
allocated
[
AllocKeys
.
PairFwd
]
...
...
@@ -437,16 +437,13 @@ def get_indice_pairs_implicit_gemm(
]
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
,
mask_int_count
)
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
)
assert
indices
.
is_cuda
,
"implicit gemm only support cuda"
ndim
=
indices
.
shape
[
1
]
-
1
kv
:
int
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ksize
,
1
)
# TODO in future we will support up to 128 kernel volume.
# assert kv <= 32, "currently only support kernel volume <= 32 to use implicit gemm"
mask_int_count
=
(
kv
+
31
)
//
32
if
1
<
mask_int_count
<
4
:
mask_int_count
=
4
assert
mask_int_count
in
[
1
,
4
]
mask_int_count
=
div_up
(
kv
,
32
)
if
not
subm
:
if
transpose
:
...
...
@@ -511,7 +508,7 @@ def get_indice_pairs_implicit_gemm(
hashdata
=
_HashData
(
out_inds
.
shape
[
0
],
use_int64_hash_k
,
indices
.
device
)
pair_mask
=
torch
.
empty
((
mask_split_count
,
indices
.
shape
[
0
]
*
mask_int_count
),
pair_mask
=
torch
.
empty
((
mask_split_count
,
indices
.
shape
[
0
]
,
mask_int_count
),
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
...
...
@@ -531,8 +528,7 @@ def get_indice_pairs_implicit_gemm(
dilation
=
dilation
,
indice_pair_mask
=
pair_mask_tv
,
backward
=
is_train
,
stream_int
=
stream
,
mask_int_count
=
mask_int_count
)
stream_int
=
stream
)
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
# CONV.stream_synchronize(stream)
...
...
@@ -549,7 +545,7 @@ def get_indice_pairs_implicit_gemm(
# so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++).
# f**k thrust
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
pair_mask_tv
[
j
],
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_tv
[
j
],
alloc
.
alloc
,
mask_argsort_tv
[
j
],
stream
,
mask_int_count
)
...
...
@@ -560,10 +556,10 @@ def get_indice_pairs_implicit_gemm(
]
if
is_train
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
,
mask_int_count
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
torch
.
Tensor
(),
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
,
mask_int_count
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
max_num_act
=
SpconvOps
.
get_handcrafted_max_act_out
(
indices
.
shape
[
0
],
ksize
,
stride
,
padding
,
dilation
)
...
...
@@ -621,7 +617,6 @@ def get_indice_pairs_implicit_gemm(
stream_int
=
stream
)
uniq_out_indices_offset_tv
=
tv
.
Tensor
()
with
timer
.
record
(
f
"unique_
{
indice_pairs_uniq
.
shape
[
0
]
}
"
,
stream
):
if
direct_table
:
uniq_cnt
=
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
,
...
...
@@ -655,7 +650,7 @@ def get_indice_pairs_implicit_gemm(
-
1
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
pair_mask_fwd
=
torch
.
zeros
((
mask_split_count
,
num_act_out
*
mask_int_count
),
pair_mask_fwd
=
torch
.
zeros
((
mask_split_count
,
num_act_out
,
mask_int_count
),
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
...
...
@@ -665,7 +660,7 @@ def get_indice_pairs_implicit_gemm(
pair_mask_bwd_tv
=
tv
.
Tensor
()
if
is_train
:
pair_mask_bwd
=
torch
.
zeros
(
(
mask_split_count
,
indices
.
shape
[
0
]
*
mask_int_count
),
(
mask_split_count
,
indices
.
shape
[
0
]
,
mask_int_count
),
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
pair_mask_bwd_tv
=
torch_tensor_to_tv
(
pair_mask_bwd
,
...
...
@@ -713,8 +708,7 @@ def get_indice_pairs_implicit_gemm(
padding
=
padding
,
dilation
=
dilation
,
transposed
=
transpose
,
stream_int
=
stream
,
mask_int_count
=
mask_int_count
)
stream_int
=
stream
)
mask_argsort_fwd
=
torch
.
empty
((
mask_split_count
,
out_inds
.
shape
[
0
]),
dtype
=
torch
.
int32
,
device
=
indices
.
device
)
...
...
@@ -766,24 +760,24 @@ def get_indice_pairs_implicit_gemm(
else
:
# if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
if
not
is_train
:
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
pair_mask_fwd_tv
[
0
],
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_fwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_fwd_tv
[
0
],
stream
,
mask_int_count
)
else
:
if
pair_mask_bwd_tv
.
dim
(
1
)
>
pair_mask_fwd_tv
.
dim
(
1
):
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_bwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_bwd_tv
[
0
],
stream
,
mask_int_count
)
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_fwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_fwd_tv
[
0
],
stream
,
mask_int_count
)
else
:
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_fwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_fwd_tv
[
0
],
stream
,
mask_int_count
)
SpconvOps
.
sort_1d_by_key_allocator
_mask_auto
(
SpconvOps
.
sort_1d_by_key_allocator
(
pair_mask_bwd_tv
[
0
],
alloc
.
alloc
,
mask_argsort_bwd_tv
[
0
],
stream
,
mask_int_count
)
...
...
@@ -808,7 +802,7 @@ def get_indice_pairs_implicit_gemm(
return
(
out_inds
,
indice_num_per_loc
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
,
mask_int_count
)
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
masks
)
def
indice_conv
(
features
:
torch
.
Tensor
,
...
...
@@ -1457,7 +1451,6 @@ def implicit_gemm(features: torch.Tensor,
mask_argsort_fwd_splits
:
List
[
torch
.
Tensor
],
num_activate_out
:
int
,
masks
:
List
[
np
.
ndarray
],
mask_int_count
:
int
,
is_train
:
bool
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
...
...
@@ -1465,16 +1458,31 @@ def implicit_gemm(features: torch.Tensor,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
act_alpha
:
float
=
0.0
,
act_beta
:
float
=
0.0
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
):
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
output_scale
:
float
=
1.0
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_add
:
Optional
[
torch
.
Tensor
]
=
None
,
output_add_scale
:
float
=
1.0
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
stream
=
get_current_stream
()
bias_tv
=
tv
.
Tensor
()
scale_tv
=
tv
.
Tensor
()
output_add_tv
=
tv
.
Tensor
()
if
output_add
is
not
None
:
assert
features
.
dtype
==
torch
.
int8
,
"fused residual add only support int8"
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
if
scale
is
not
None
:
scale_tv
=
torch_tensor_to_tv
(
scale
)
if
output_add
is
not
None
:
output_add_tv
=
torch_tensor_to_tv
(
output_add
)
if
not
features
.
is_contiguous
():
features
=
features
.
contiguous
()
assert
features
.
is_contiguous
()
assert
filters
.
is_contiguous
()
if
output_dtype
is
None
:
output_dtype
=
features
.
dtype
if
SPCONV_CPP_GEMM
and
CONV_CPP
is
not
None
:
alloc
=
TorchAllocator
(
features
.
device
)
...
...
@@ -1497,13 +1505,15 @@ def implicit_gemm(features: torch.Tensor,
if
fp32_accum
is
None
:
fp32_accum
=
False
arch
=
get_arch
()
output_dtype_tv
=
_TORCH_DTYPE_TO_TV
[
output_dtype
]
mask_width
,
tune_res_cpp
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
mask_int_count
,
arch
,
is_train
,
is_subm
,
stream
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
,
bias_tv
,
act_alpha
,
act_beta
,
act_type
,
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
)
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
,
output_scale
=
output_scale
,
scale
=
scale_tv
,
output_add
=
output_add_tv
,
output_add_scale
=
output_add_scale
,
output_dtype
=
output_dtype_tv
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
if
is_train
:
...
...
@@ -1515,8 +1525,8 @@ def implicit_gemm(features: torch.Tensor,
# t = time.time()
if
features
.
dtype
==
torch
.
int8
or
features
.
dtype
==
torch
.
qint8
:
raise
NotImplementedError
(
"work in progress"
)
#
if features.dtype == torch.int8 or features.dtype == torch.qint8:
#
raise NotImplementedError("work in progress")
# here filters is KRSC
masks_ints
=
[
m
.
item
()
for
m
in
masks
]
out_channel
=
filters
.
shape
[
0
]
...
...
@@ -1524,13 +1534,14 @@ def implicit_gemm(features: torch.Tensor,
num_split
=
len
(
pair_mask_fwd_splits
)
filters
=
filters
.
reshape
(
out_channel
,
-
1
,
filters
.
shape
[
-
1
])
kv
=
filters
.
shape
[
1
]
mask_int_count
=
div_up
(
kv
,
32
)
if
is_subm
:
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
dtype
=
output_
dtype
,
device
=
features
.
device
)
else
:
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
dtype
=
output_
dtype
,
device
=
features
.
device
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
...
...
@@ -1568,13 +1579,13 @@ def implicit_gemm(features: torch.Tensor,
stream
=
stream
,
fp32_accum
=
fp32_accum
,
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
,
mask_int_count
=
mask_int_count
)
bias
=
bias_tv
,
scale
=
scale_tv
)
mask_width
=
tune_res
.
algo_desp
.
tile_shape
[
0
]
if
is_train
:
mask_output_fwd
=
torch
.
empty
(
[
num_split
,
codeops
.
div_up
(
num_activate_out
,
mask_width
)
*
mask_int_count
],
codeops
.
div_up
(
num_activate_out
,
mask_width
)
,
mask_int_count
],
dtype
=
torch
.
int32
,
device
=
features
.
device
)
# pytorch don't support uint32.
...
...
@@ -1597,12 +1608,16 @@ def implicit_gemm(features: torch.Tensor,
bias_tv
=
tv
.
Tensor
()
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
alpha
=
1.0
if
tune_res
.
algo_desp
.
is_int8_inference
:
alpha
=
output_scale
with
timer
.
record
(
"implicit_gemm"
,
stream
):
for
j
in
range
(
num_split
):
beta
=
0
if
j
==
0
else
1
if
bias
is
not
None
:
if
bias
is
not
None
and
not
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
1
if
output_add
is
not
None
and
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
output_add_scale
CONV
.
run_with_tuned_result
(
tune_res
,
ConvOpType
.
kForward
,
...
...
@@ -1616,6 +1631,7 @@ def implicit_gemm(features: torch.Tensor,
reverse_mask
=
False
,
mask_filter
=
masks_ints
[
j
],
mask_width
=-
1
,
alpha
=
alpha
,
beta
=
beta
,
stream
=
stream
,
verbose
=
False
,
...
...
@@ -1623,91 +1639,8 @@ def implicit_gemm(features: torch.Tensor,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
mask_int_count
=
mask_int_count
)
# INT8_TEST = True
# if INT8_TEST:
# if features.shape[1] % 32 != 0:
# return out_features, mask_output_fwd, mask_width
# features = features.to(torch.int8)
# filters = filters.to(torch.int8)
# out_features_i8 = out_features.to(torch.int8)
# features_tv = torch_tensor_to_tv(features)
# filters_tv = torch_tensor_to_tv(filters)
# out_features_i8_tv = torch_tensor_to_tv(out_features_i8)
# tune_res = CONV.get_tuned_algo(ConvOpType.kForward, features_tv.dtype,
# filters_tv.dtype, out_features_i8_tv.dtype,
# out_channel, in_channel, arch)
# if tune_res is None:
# tune_res, _ = CONV.tune_and_cache(
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# NHWC,
# KRSC,
# NHWC,
# arch,
# mask=pair_mask_fwd_split_tvs[0],
# mask_argsort=mask_argsort_fwd_split_tvs[0],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks[0].item(),
# stream=stream,
# fp32_accum=fp32_accum)
# mask_width = tune_res.algo_desp.tile_shape[0]
# if is_train:
# mask_output_fwd = torch.empty(
# [num_split,
# codeops.div_up(num_activate_out, mask_width)],
# dtype=torch.int32,
# device=features.device)
# # pytorch don't support uint32.
# mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd,
# dtype=tv.uint32)
# mask_output_fwd_tvs = [mask_output_fwd_tv[j] for j in range(num_split)]
# else:
# mask_output_fwd = None
# mask_output_fwd_tv = tv.Tensor()
# mask_output_fwd_tvs = [tv.Tensor() for _ in range(num_split)]
# # CONV.stream_synchronize(stream)
# # print("FPREPARE", time.time() - t)
# # # t = time.time()
# # CONV.stream_synchronize(stream)
# # t = time.time()
# # print(tune_res.algo_desp)
# with tv.measure_and_print(f"i8 time {features.shape[0]}-{in_channel}-{out_channel}"):
# with timer.record("implicit_gemm_i8", stream):
# for j in range(num_split):
# beta = 0 if j == 0 else 1
# CONV.run_with_tuned_result(
# tune_res,
# ConvOpType.kForward,
# features_tv,
# filters_tv,
# out_features_i8_tv,
# mask=pair_mask_fwd_split_tvs[j],
# mask_argsort=mask_argsort_fwd_split_tvs[j],
# mask_output=mask_output_fwd_tvs[j],
# indices=pair_fwd_tv,
# reverse_mask=False,
# mask_filter=masks_ints[j],
# mask_width=-1,
# beta=beta,
# stream=stream,
# verbose=False)
# torch.cuda.synchronize()
# if DEBUG:
# CONV.stream_synchronize(stream)
# dura = time.time() - t
# print("F", tune_res.algo_desp, dura)
# print(out_features.mean(), out_features.max(), out_features.min())
scale
=
scale_tv
,
output_add
=
output_add
)
return
out_features
,
mask_output_fwd
,
mask_width
...
...
@@ -1722,7 +1655,6 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_argsort_bwd_splits
:
List
[
torch
.
Tensor
],
mask_output_fwd
:
Optional
[
torch
.
Tensor
],
masks
:
List
[
np
.
ndarray
],
mask_int_count
:
int
,
mask_width
:
int
,
is_subm
:
bool
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
...
...
@@ -1782,7 +1714,7 @@ def implicit_gemm_backward(features: torch.Tensor,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_output_fwd_tv
,
mask_tv
,
mask_int_count
,
arch
,
mask_width
,
is_subm
,
stream
,
mask_output_fwd_tv
,
mask_tv
,
arch
,
mask_width
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
,
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
...
...
@@ -1802,6 +1734,8 @@ def implicit_gemm_backward(features: torch.Tensor,
filters
=
filters
.
reshape
(
out_channel
,
-
1
,
filters
.
shape
[
-
1
])
kv
=
filters
.
shape
[
1
]
need_dynamic_mask
=
kv
>
32
mask_int_count
=
div_up
(
kv
,
32
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_bwd_tv
=
torch_tensor_to_tv
(
pair_bwd
)
...
...
@@ -1831,11 +1765,12 @@ def implicit_gemm_backward(features: torch.Tensor,
dgrad_tune_res
=
CONV
.
get_tuned_algo
(
ConvOpType
.
kBackwardInput
,
din_tv
.
dtype
,
filters_tv
.
dtype
,
dout_tv
.
dtype
,
out_channel
,
in_channel
,
arch
)
in_channel
,
arch
,
need_dynamic_mask
=
need_dynamic_mask
)
wgrad_tune_res
=
CONV
.
get_tuned_algo
(
ConvOpType
.
kBackwardWeight
,
features_tv
.
dtype
,
dfilters_tv
.
dtype
,
dout_tv
.
dtype
,
out_channel
,
in_channel
,
arch
,
mask_width
)
in_channel
,
arch
,
mask_width
,
need_dynamic_mask
=
need_dynamic_mask
)
if
dgrad_tune_res
is
None
:
# TODO split mask maybe completely invalid
...
...
@@ -1861,8 +1796,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter
=
masks
[
0
].
item
(),
stream
=
stream
,
fp32_accum
=
fp32_accum
,
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
,
mask_int_count
=
mask_int_count
)
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
)
if
wgrad_tune_res
is
None
:
wgrad_tune_res
,
_
=
CONV
.
tune_and_cache
(
ConvOpType
.
kBackwardWeight
,
...
...
@@ -1881,8 +1815,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_output
=
tv
.
Tensor
(),
mask_width
=
mask_width
,
stream
=
stream
,
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
,
mask_int_count
=
mask_int_count
)
use_tf32
=
constants
.
SPCONV_ALLOW_TF32
)
workspace_size
=
CONV
.
query_workspace_size
(
wgrad_tune_res
.
algo_desp
,
wgrad_tune_res
.
splitk
,
ConvOpType
.
kBackwardWeight
,
...
...
@@ -1919,8 +1852,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter
=
masks
[
j
].
item
(),
mask_width
=-
1
,
beta
=
beta
,
stream
=
stream
,
mask_int_count
=
mask_int_count
)
stream
=
stream
)
# for backward weight, beta = 0 because each split
# handle different kernel locations.
# TODO remove D iterator in backward weight kernel
...
...
@@ -1939,8 +1871,7 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_width
=
mask_width
,
beta
=
0
,
workspace
=
workspace_tv
,
stream
=
stream
,
mask_int_count
=
mask_int_count
)
stream
=
stream
)
return
(
din
,
dfilters
.
reshape
(
filters_shape
))
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment