Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
b1c57a31
Commit
b1c57a31
authored
Jan 03, 2023
by
yan.yan
Browse files
still working on int8
parent
aa26c99e
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1871 additions
and
343 deletions
+1871
-343
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+226
-19
spconv/build.py
spconv/build.py
+1
-1
spconv/core.py
spconv/core.py
+32
-2
spconv/core_cc/csrc/sparse/alloc.pyi
spconv/core_cc/csrc/sparse/alloc.pyi
+4
-2
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+2
-1
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+2
-2
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+2
-2
spconv/csrc/sparse/maxpool.py
spconv/csrc/sparse/maxpool.py
+3
-3
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+720
-219
spconv/pytorch/core.py
spconv/pytorch/core.py
+11
-4
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+30
-9
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+21
-0
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+360
-0
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+133
-10
spconv/pytorch/quantization/fuse_mapping.py
spconv/pytorch/quantization/fuse_mapping.py
+33
-60
spconv/pytorch/quantization/intrinsic/__init__.py
spconv/pytorch/quantization/intrinsic/__init__.py
+15
-0
spconv/pytorch/quantization/intrinsic/modules.py
spconv/pytorch/quantization/intrinsic/modules.py
+30
-3
spconv/pytorch/quantization/intrinsic/qat/__init__.py
spconv/pytorch/quantization/intrinsic/qat/__init__.py
+15
-0
spconv/pytorch/quantization/intrinsic/qat/modules.py
spconv/pytorch/quantization/intrinsic/qat/modules.py
+216
-6
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
+15
-0
No files found.
example/mnist/mnist_qat.py
View file @
b1c57a31
...
@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
...
@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
from
torch.optim.lr_scheduler
import
StepLR
from
torch.optim.lr_scheduler
import
StepLR
import
contextlib
import
contextlib
import
torch.cuda.amp
import
torch.cuda.amp
import
torch.ao.quantization
from
torch.ao.quantization
import
QuantStub
,
DeQuantStub
import
torch.ao.quantization.quantize_fx
as
qfx
from
spconv.pytorch.quantization.fake_q
import
get_default_spconv_qconfig_mapping
import
spconv.pytorch.quantization
as
spconvq
from
spconv.pytorch.quantization
import
get_default_spconv_trt_ptq_qconfig
from
torch.ao.quantization
import
get_default_qconfig_mapping
from
spconv.pytorch.quantization.backend_cfg
import
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from
torch.ao.quantization.fx._lower_to_native_backend
import
STATIC_LOWER_FUSED_MODULE_MAP
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
identity_ctx
():
def
identity_ctx
():
yield
yield
class
SubMConvBNReLU
(
spconv
.
SparseSequential
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
1
,
groups
=
1
):
padding
=
(
kernel_size
-
1
)
//
2
super
(
SubMConvBNReLU
,
self
).
__init__
(
spconv
.
SubMConv2d
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
padding
,
groups
=
groups
,
bias
=
False
),
nn
.
BatchNorm1d
(
out_planes
,
momentum
=
0.1
),
# Replace with ReLU
nn
.
ReLU
(
inplace
=
False
)
)
class
SparseConvBNReLU
(
spconv
.
SparseSequential
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
1
,
groups
=
1
):
padding
=
(
kernel_size
-
1
)
//
2
super
(
SparseConvBNReLU
,
self
).
__init__
(
spconv
.
SparseConv2d
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
padding
,
groups
=
groups
,
bias
=
False
),
nn
.
BatchNorm1d
(
out_planes
,
momentum
=
0.1
),
# Replace with ReLU
nn
.
ReLU
(
inplace
=
False
)
)
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
super
(
Net
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
self
.
net
=
spconv
.
SparseSequential
(
nn
.
BatchNorm1d
(
1
),
SubMConvBNReLU
(
1
,
32
,
3
),
spconv
.
SubMConv2d
(
1
,
32
,
3
,
1
),
SubMConvBNReLU
(
32
,
64
,
3
),
nn
.
ReLU
(),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
spconv
.
SubMConv2d
(
32
,
64
,
3
,
1
),
nn
.
ReLU
(),
spconv
.
SparseConv2d
(
64
,
64
,
2
,
2
),
spconv
.
ToDense
(),
spconv
.
ToDense
(),
)
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
x_sp
:
spconv
.
SparseConvTensor
):
# def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor
# x = self.quant(x)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
# x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
dropout1
(
x
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
# x = self.dequant(x)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
def
forward
(
self
,
x
:
torch
.
Tensor
):
class
NetV2
(
nn
.
Module
):
def
__init__
(
self
):
super
(
NetV2
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
spconv
.
ToDense
(),
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
batch_size
:
int
):
# x: [N, 28, 28, 1], must be NHWC tensor
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
x
.
reshape
(
-
1
,
28
,
28
,
1
))
x
=
self
.
quant
(
features
)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp
=
spconv
.
SparseConvTensor
(
features
,
indices
,
[
28
,
28
],
batch_size
)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
torch
.
flatten
(
x
,
1
)
...
@@ -58,10 +120,93 @@ class Net(nn.Module):
...
@@ -58,10 +120,93 @@ class Net(nn.Module):
x
=
F
.
relu
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
class
NetPTQ
(
nn
.
Module
):
"""pytorch currently don't support cuda int8 inference, so
we only use sparse ops here.
"""
def
__init__
(
self
):
super
(
NetPTQ
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 14x14
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 7x7
SparseConvBNReLU
(
64
,
64
,
3
,
2
,
1
),
# 4x4
spconv
.
SparseConv2d
(
64
,
10
,
4
,
4
),
spconv
.
ToDense
(),
)
# self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10)
# self.dropout1 = nn.Dropout2d(0.25)
# self.dropout2 = nn.Dropout2d(0.5)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
batch_size
:
int
):
# x: [N, 28, 28, 1], must be NHWC tensor
features
=
self
.
quant
(
features
)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp
=
spconv
.
SparseConvTensor
(
features
,
indices
,
[
28
,
28
],
batch_size
)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x_sp
=
self
.
net
(
x_sp
)
# print(x_sp.shape)
x
=
x_sp
x
=
torch
.
flatten
(
x
,
1
)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
return
output
class
NetDense
(
nn
.
Module
):
def
__init__
(
self
):
super
(
NetDense
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
3
,
1
)
self
.
conv2
=
nn
.
Conv2d
(
32
,
64
,
3
,
1
)
self
.
dropout1
=
nn
.
Dropout
(
0.25
)
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
self
.
fc1
=
nn
.
Linear
(
9216
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
x
):
x
=
self
.
quant
(
x
)
x
=
self
.
conv1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
F
.
relu
(
x
)
x
=
F
.
max_pool2d
(
x
,
2
)
x
=
self
.
dropout1
(
x
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
model
.
train
()
model
.
train
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
...
@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
...
@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
with
amp_ctx
:
with
amp_ctx
:
output
=
model
(
data
)
if
args
.
sparse
:
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
data
.
reshape
(
-
1
,
28
,
28
,
1
))
# output = model(data_sp)
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
else
:
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
loss
=
F
.
nll_loss
(
output
,
target
)
scale
=
1.0
scale
=
1.0
if
args
.
fp16
:
if
args
.
fp16
:
...
@@ -114,8 +265,12 @@ def test(args, model, device, test_loader):
...
@@ -114,8 +265,12 @@ def test(args, model, device, test_loader):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
with
amp_ctx
:
with
amp_ctx
:
if
args
.
sparse
:
output
=
model
(
data
)
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
data
.
reshape
(
-
1
,
28
,
28
,
1
))
# output = model(data_sp)
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
else
:
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
pred
=
output
.
argmax
(
pred
=
output
.
argmax
(
...
@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
...
@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
100.
*
correct
/
len
(
test_loader
.
dataset
)))
100.
*
correct
/
len
(
test_loader
.
dataset
)))
def
calibrate
(
args
,
model
:
torch
.
nn
.
Module
,
data_loader
,
device
):
model
.
eval
()
with
torch
.
no_grad
():
for
image
,
target
in
data_loader
:
image
=
image
.
to
(
device
)
if
args
.
sparse
:
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
image
.
reshape
(
-
1
,
28
,
28
,
1
))
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
# output = model(data_sp)
else
:
output
=
model
(
image
)
def
main
():
def
main
():
# Training settings
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
@@ -146,7 +314,7 @@ def main():
...
@@ -146,7 +314,7 @@ def main():
help
=
'input batch size for testing (default: 1000)'
)
help
=
'input batch size for testing (default: 1000)'
)
parser
.
add_argument
(
'--epochs'
,
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
type
=
int
,
default
=
1
4
,
default
=
1
,
metavar
=
'N'
,
metavar
=
'N'
,
help
=
'number of epochs to train (default: 14)'
)
help
=
'number of epochs to train (default: 14)'
)
parser
.
add_argument
(
'--lr'
,
parser
.
add_argument
(
'--lr'
,
...
@@ -168,6 +336,10 @@ def main():
...
@@ -168,6 +336,10 @@ def main():
default
=
1
,
default
=
1
,
metavar
=
'S'
,
metavar
=
'S'
,
help
=
'random seed (default: 1)'
)
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--sparse'
,
action
=
'store_true'
,
default
=
True
,
help
=
'use sparse conv network instead of dense'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--log-interval'
,
'--log-interval'
,
type
=
int
,
type
=
int
,
...
@@ -190,8 +362,14 @@ def main():
...
@@ -190,8 +362,14 @@ def main():
torch
.
manual_seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
model
=
NetPTQ
().
to
(
device
)
else
:
model
=
NetDense
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
datasets
.
MNIST
(
'../data'
,
'../data'
,
...
@@ -218,17 +396,46 @@ def main():
...
@@ -218,17 +396,46 @@ def main():
shuffle
=
True
,
shuffle
=
True
,
**
kwargs
)
**
kwargs
)
model
=
Net
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
test
(
args
,
model
,
device
,
test_loader
)
test
(
args
,
model
,
device
,
test_loader
)
scheduler
.
step
()
scheduler
.
step
()
# if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt")
model
.
eval
()
STATIC_LOWER_FUSED_MODULE_MAP
.
update
(
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
)
if
not
args
.
sparse
:
model
=
model
.
cpu
()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
qconfig_mapping
=
get_default_spconv_qconfig_mapping
(
False
)
prepare_cfg
=
spconvq
.
get_spconv_prepare_custom_config
()
backend_cfg
=
spconvq
.
get_spconv_backend_config
()
convert_cfg
=
spconvq
.
get_spconv_convert_custom_config
()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model
=
qfx
.
prepare_fx
(
model
,
qconfig_mapping
,
(),
backend_config
=
backend_cfg
,
prepare_custom_config
=
prepare_cfg
)
# prepared_model.print_readable()
print
([
type
(
m
)
for
m
in
prepared_model
.
modules
()])
print
(
prepared_model
)
# calibrate: run model with some inputs
calibrate
(
args
,
prepared_model
,
test_loader
,
qdevice
)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model
=
qfx
.
convert_to_reference_fx
(
prepared_model
,
convert_cfg
,
qconfig_mapping
=
qconfig_mapping
,
backend_config
=
backend_cfg
)
print
([
type
(
m
)
for
m
in
converted_model
.
modules
()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print
(
converted_model
)
if
args
.
save_model
:
test
(
args
,
converted_model
,
qdevice
,
test_loader
)
torch
.
save
(
model
.
state_dict
(),
"mnist_cnn.pt"
)
breakpoint
(
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
spconv/build.py
View file @
b1c57a31
...
@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
...
@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
if
project_is_installed
(
PACKAGE_NAME
)
and
project_is_editable
(
if
project_is_installed
(
PACKAGE_NAME
)
and
project_is_editable
(
PACKAGE_NAME
)
and
not
DISABLE_JIT
:
PACKAGE_NAME
)
and
not
DISABLE_JIT
and
False
:
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
,
SHUFFLE_AMPERE_PARAMS
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
,
SHUFFLE_AMPERE_PARAMS
from
spconv.core
import
IMPLGEMM_SIMT_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
,
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_AMPERE_PARAMS
from
spconv.core
import
IMPLGEMM_SIMT_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
,
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_AMPERE_PARAMS
...
...
spconv/core.py
View file @
b1c57a31
...
@@ -698,7 +698,23 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -698,7 +698,23 @@ IMPLGEMM_AMPERE_PARAMS = [
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
is_nvrtc
=
True
,
int8_inference
=
True
),
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
...
@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
is_nvrtc
=
True
,
int8_inference
=
True
),
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
...
...
spconv/core_cc/csrc/sparse/alloc.pyi
View file @
b1c57a31
...
@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
...
@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import Tensor
class ExternalAllocator:
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False
, scale: float = 1.0
) -> Tensor:
"""
"""
Args:
Args:
name:
name:
...
@@ -11,9 +11,10 @@ class ExternalAllocator:
...
@@ -11,9 +11,10 @@ class ExternalAllocator:
device:
device:
stream:
stream:
is_temp_memory:
is_temp_memory:
scale:
"""
"""
...
...
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False
, scale: float = 1.0
) -> Tensor:
"""
"""
Args:
Args:
name:
name:
...
@@ -22,6 +23,7 @@ class ExternalAllocator:
...
@@ -22,6 +23,7 @@ class ExternalAllocator:
device:
device:
stream:
stream:
is_temp_memory:
is_temp_memory:
scale:
"""
"""
...
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
b1c57a31
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0
, output_dtype: int = -1
) -> Tuple[int, Any]:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -91,6 +91,7 @@ class ConvGemmOps:
...
@@ -91,6 +91,7 @@ class ConvGemmOps:
scale:
scale:
output_add:
output_add:
output_add_scale:
output_add_scale:
output_dtype:
"""
"""
...
...
@staticmethod
@staticmethod
...
...
spconv/csrc/sparse/alloc.py
View file @
b1c57a31
...
@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
...
@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
...
@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
spconv/csrc/sparse/convops.py
View file @
b1c57a31
...
@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
if (is_subm){{
if (is_subm){{
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int
, false /*is_temp*/, output_scale
);
}}else{{
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int
, false /*is_temp*/, output_scale
);
}}
}}
// auto start_ev = tv::CUDAEvent();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// start_ev.record(stream_int);
...
...
spconv/csrc/sparse/maxpool.py
View file @
b1c57a31
...
@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto nhot = out_inds.dim(0);
auto nhot = out_inds.dim(0);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
...
spconv/pytorch/conv.py
View file @
b1c57a31
...
@@ -36,7 +36,7 @@ from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC, SPCONV_DEB
...
@@ -36,7 +36,7 @@ from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC, SPCONV_DEB
from
spconv.utils
import
nullcontext
from
spconv.utils
import
nullcontext
from
torch.nn.init
import
calculate_gain
from
torch.nn.init
import
calculate_gain
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
from
collections
import
namedtuple
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
FILTER_HWIO
=
False
FILTER_HWIO
=
False
...
@@ -55,12 +55,7 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
...
@@ -55,12 +55,7 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
class
SparseConvolution
(
SparseModule
):
class
SparseConvolutionBase
:
__constants__
=
[
'stride'
,
'padding'
,
'dilation'
,
'groups'
,
'bias'
,
'subm'
,
'inverse'
,
'transposed'
,
'output_padding'
]
def
__init__
(
self
,
def
__init__
(
self
,
ndim
:
int
,
ndim
:
int
,
in_channels
:
int
,
in_channels
:
int
,
...
@@ -69,7 +64,7 @@ class SparseConvolution(SparseModule):
...
@@ -69,7 +64,7 @@ class SparseConvolution(SparseModule):
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
...
@@ -81,21 +76,17 @@ class SparseConvolution(SparseModule):
...
@@ -81,21 +76,17 @@ class SparseConvolution(SparseModule):
record_voxel_count
:
bool
=
False
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
act_beta
:
float
=
0
):
name
=
None
):
super
(
SparseConvolution
,
self
).
__init__
(
name
=
name
)
assert
groups
==
1
,
"don't support groups for now"
assert
groups
==
1
,
"don't support groups for now"
self
.
ndim
=
ndim
self
.
ndim
=
ndim
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
kernel_size
=
expand_nd
(
ndim
,
kernel_size
)
self
.
kernel_size
=
expand_nd
(
ndim
,
kernel_size
)
self
.
stride
=
expand_nd
(
ndim
,
stride
)
self
.
stride
=
expand_nd
(
ndim
,
stride
)
kv
=
int
(
np
.
prod
(
self
.
kernel_size
))
kv
=
int
(
np
.
prod
(
self
.
kernel_size
))
kv_stride
=
int
(
np
.
prod
(
self
.
stride
))
kv_stride
=
int
(
np
.
prod
(
self
.
stride
))
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
dilation
=
expand_nd
(
ndim
,
dilation
)
self
.
padding
=
expand_nd
(
ndim
,
padding
)
self
.
padding
=
expand_nd
(
ndim
,
padding
)
self
.
conv1x1
=
kv
==
1
self
.
conv1x1
=
kv
==
1
# TODO we should deprecate support for ksize == 1 but stride != 1.
# TODO we should deprecate support for ksize == 1 but stride != 1.
if
not
subm
:
if
not
subm
:
...
@@ -110,11 +101,6 @@ class SparseConvolution(SparseModule):
...
@@ -110,11 +101,6 @@ class SparseConvolution(SparseModule):
self
.
groups
=
groups
self
.
groups
=
groups
self
.
subm
=
subm
self
.
subm
=
subm
self
.
indice_key
=
indice_key
self
.
indice_key
=
indice_key
if
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self
.
register_buffer
(
_MAX_NUM_VOXELS_DURING_TRAINING
,
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
))
self
.
record_voxel_count
=
record_voxel_count
self
.
record_voxel_count
=
record_voxel_count
if
algo
is
None
:
if
algo
is
None
:
if
kv
<=
128
and
not
CPU_ONLY_BUILD
:
if
kv
<=
128
and
not
CPU_ONLY_BUILD
:
...
@@ -131,169 +117,37 @@ class SparseConvolution(SparseModule):
...
@@ -131,169 +117,37 @@ class SparseConvolution(SparseModule):
self
.
algo
=
algo
self
.
algo
=
algo
self
.
fp32_accum
=
fp32_accum
self
.
fp32_accum
=
fp32_accum
# self.algo = ConvAlgo.Native
# self.algo = ConvAlgo.Native
if
self
.
algo
==
ConvAlgo
.
Native
and
not
ALL_WEIGHT_IS_KRSC
:
if
self
.
algo
==
ConvAlgo
.
Native
and
not
ALL_WEIGHT_IS_KRSC
:
if
FILTER_HWIO
:
if
FILTER_HWIO
:
# RSCK
# RSCK
self
.
weight
=
Parameter
(
weight_shape
=
[
*
self
.
kernel_size
,
in_channels
,
out_channels
]
torch
.
Tensor
(
*
self
.
kernel_size
,
in_channels
,
out_channels
))
else
:
else
:
# RSKC
# RSKC
self
.
weight
=
Parameter
(
weight_shape
=
[
*
self
.
kernel_size
,
out_channels
,
in_channels
]
torch
.
Tensor
(
*
self
.
kernel_size
,
out_channels
,
in_channels
))
else
:
else
:
# KRSC
# KRSC
self
.
weight
=
Parameter
(
weight_shape
=
[
out_channels
,
*
self
.
kernel_size
,
in_channels
]
torch
.
Tensor
(
out_channels
,
*
self
.
kernel_size
,
in_channels
))
self
.
weight_shape
=
weight_shape
if
bias
:
self
.
bias
=
Parameter
(
torch
.
Tensor
(
out_channels
))
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
act_type
=
act_type
self
.
act_type
=
act_type
self
.
act_alpha
=
act_alpha
self
.
act_alpha
=
act_alpha
self
.
act_beta
=
act_beta
self
.
act_beta
=
act_beta
self
.
scale
=
1.0
self
.
enable_int8_test_mode
:
bool
=
False
self
.
zero_point
=
0
self
.
_int8_weight
=
torch
.
Tensor
()
# calculated by max(abs(weight)) for each channel
self
.
_int8_weight_scale
=
torch
.
Tensor
()
# calculated by scale self.bias with _int8_input_scale
self
.
_int8_bias
=
torch
.
Tensor
()
# int8 inference must set _int8_input_scale
self
.
_int8_input_scale
:
Optional
[
float
]
=
None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self
.
_int8_output_scale
:
Optional
[
float
]
=
None
if
self
.
conv1x1
:
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
self
.
reset_parameters
()
if
hasattr
(
self
,
"_register_load_state_dict_pre_hook"
):
self
.
_register_load_state_dict_pre_hook
(
self
.
_load_weight_different_layout
)
def
get_max_num_voxels
(
self
)
->
Optional
[
torch
.
Tensor
]:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
None
def
set_int8_test
(
self
,
enable
:
bool
,
input_scale
:
float
,
output_scale
:
Optional
[
float
]
=
None
,
weight_scale
:
Optional
[
torch
.
Tensor
]
=
None
):
self
.
_int8_input_scale
=
input_scale
self
.
_int8_output_scale
=
output_scale
if
weight_scale
is
not
None
:
self
.
_int8_weight_scale
=
weight_scale
self
.
enable_int8_test_mode
=
enable
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
name
=
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
name
not
in
state_dict
:
state_dict
[
name
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
)
if
not
SAVED_WEIGHT_LAYOUT
:
return
key
=
prefix
+
"weight"
assert
key
in
state_dict
ndim
=
self
.
ndim
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
,
*
range
(
ndim
),
ndim
+
1
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"RSCK"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
+
1
,
*
range
(
ndim
),
ndim
).
contiguous
()
if
ALL_WEIGHT_IS_KRSC
or
self
.
algo
!=
ConvAlgo
.
Native
:
# in spconv 2.2, we only support KRSC layout.
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
,
*
range
(
ndim
),
ndim
+
1
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"RSCK"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
+
1
,
*
range
(
ndim
),
ndim
).
contiguous
()
else
:
if
self
.
algo
==
ConvAlgo
.
Native
:
# to RSCK
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
*
range
(
ndim
),
ndim
+
1
,
ndim
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"KRSC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
*
range
(
1
,
ndim
+
1
),
0
,
ndim
+
1
).
contiguous
()
def
extra_repr
(
self
):
s
=
(
'{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}'
)
if
self
.
padding
!=
(
0
,
)
*
len
(
self
.
padding
):
s
+=
', padding={padding}'
if
self
.
dilation
!=
(
1
,
)
*
len
(
self
.
dilation
):
s
+=
', dilation={dilation}'
if
self
.
output_padding
!=
(
0
,
)
*
len
(
self
.
output_padding
):
s
+=
', output_padding={output_padding}'
if
self
.
groups
!=
1
:
s
+=
', groups={groups}'
if
self
.
bias
is
None
:
s
+=
', bias=False'
if
self
.
algo
is
not
None
:
s
+=
f
', algo=
{
self
.
algo
}
'
return
s
.
format
(
**
self
.
__dict__
)
def
_calculate_fan_in_and_fan_out
(
self
):
receptive_field_size
=
1
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for
s
in
self
.
kernel_size
:
receptive_field_size
*=
s
fan_in
=
self
.
in_channels
*
receptive_field_size
fan_out
=
self
.
out_channels
*
receptive_field_size
return
fan_in
,
fan_out
def
_calculate_correct_fan
(
self
,
mode
):
mode
=
mode
.
lower
()
valid_modes
=
[
'fan_in'
,
'fan_out'
]
if
mode
not
in
valid_modes
:
raise
ValueError
(
"Mode {} not supported, please use one of {}"
.
format
(
mode
,
valid_modes
))
fan_in
,
fan_out
=
self
.
_calculate_fan_in_and_fan_out
()
return
fan_in
if
mode
==
'fan_in'
else
fan_out
def
_custom_kaiming_uniform_
(
self
,
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
r
"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan
=
self
.
_calculate_correct_fan
(
mode
)
gain
=
calculate_gain
(
nonlinearity
,
a
)
std
=
gain
/
math
.
sqrt
(
fan
)
bound
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
with
torch
.
no_grad
():
return
tensor
.
uniform_
(
-
bound
,
bound
)
def
reset_parameters
(
self
):
def
_conv_forward
(
self
,
training
:
bool
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
,
if
SPCONV_DEBUG_WEIGHT
:
channel_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
float
]
=
None
,
name
:
Optional
[
str
]
=
None
,
self
.
_custom_kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
0.005
))
sparse_unique_name
:
str
=
""
,
else
:
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
self
.
_custom_kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
5
))
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
):
if
self
.
bias
is
not
None
:
# assert isinstance(input, SparseConvTensor)
fan_in
,
_
=
self
.
_calculate_fan_in_and_fan_out
()
is_int8
=
input
.
is_quantized
and
weight
.
is_quantized
bound
=
1
/
math
.
sqrt
(
fan_in
)
if
is_int8
:
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
assert
output_scale
is
not
None
and
channel_scale
is
not
None
,
"int8 must be called in static quantized module"
assert
bias
is
not
None
,
"currently you must specify a bias"
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
input
,
self
.
weight
,
self
.
bias
,
add_input
)
def
_conv_forward
(
self
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
assert
isinstance
(
input
,
SparseConvTensor
)
assert
input
.
features
.
shape
[
assert
input
.
features
.
shape
[
1
]
==
self
.
in_channels
,
"channel size mismatch"
1
]
==
self
.
in_channels
,
"channel size mismatch"
features
=
input
.
features
features
=
input
.
features
...
@@ -301,35 +155,38 @@ class SparseConvolution(SparseModule):
...
@@ -301,35 +155,38 @@ class SparseConvolution(SparseModule):
indices
=
input
.
indices
indices
=
input
.
indices
spatial_shape
=
input
.
spatial_shape
spatial_shape
=
input
.
spatial_shape
batch_size
=
input
.
batch_size
batch_size
=
input
.
batch_size
bias_for_training
=
bias
if
self
.
training
else
None
bias_for_training
=
bias
if
training
else
None
bias_for_infer
=
bias
if
not
self
.
training
else
None
bias_for_infer
=
bias
if
not
training
else
None
output_scale
=
None
output_add_scale
=
1.0
output_add_scale
=
1.0
if
self
.
enable_int8_test_mode
:
if
is_int8
:
assert
not
self
.
training
,
"must in eval mode"
assert
self
.
algo
==
ConvAlgo
.
MaskImplicitGemm
,
"int8 inference only support MaskImplicitGemm"
assert
bias_for_infer
is
not
None
,
"conv-bn-relu must be fused"
assert
self
.
_int8_input_scale
is
not
None
if
features
.
dtype
!=
torch
.
int8
:
# quantize
features
=
torch
.
clamp
(
torch
.
round
(
features
/
self
.
_int8_input_scale
),
-
128
,
127
).
to
(
torch
.
int8
)
output_scale
=
self
.
_int8_output_scale
int8_out_scale
=
output_scale
if
int8_out_scale
is
None
:
int8_out_scale
=
1
if
add_input
is
not
None
:
if
add_input
is
not
None
:
assert
add_input
.
int8_scale
is
not
None
,
"only support int8 add"
output_add_scale
=
add_input
.
q_scale
()
output_add_scale
=
add_input
.
int8_scale
# if self.enable_int8_test_mode:
if
self
.
_int8_weight
.
numel
()
==
0
:
# assert not self.training, "must in eval mode"
with
torch
.
no_grad
():
# assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
assert
ALL_WEIGHT_IS_KRSC
# assert bias_for_infer is not None, "conv-bn-relu must be fused"
weight_scales
=
torch
.
abs
(
weight
).
view
(
self
.
out_channels
,
-
1
).
max
(
1
)[
0
]
# assert self._int8_input_scale is not None
num_1s
=
[
1
]
*
(
self
.
ndim
+
1
)
# if features.dtype != torch.int8:
self
.
_int8_weight
=
(
weight
/
weight_scales
.
view
(
self
.
out_channels
,
*
num_1s
)
*
127
).
to
(
torch
.
int8
)
# # quantize
if
self
.
_int8_weight_scale
.
numel
()
==
0
:
# features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
self
.
_int8_weight_scale
=
int8_out_scale
/
(
self
.
_int8_input_scale
*
weight_scales
)
# output_scale = self._int8_output_scale
self
.
_int8_bias
=
bias_for_infer
*
int8_out_scale
# int8_out_scale = output_scale
if
self
.
training
:
# if int8_out_scale is None:
# int8_out_scale = 1
# if add_input is not None:
# assert add_input.int8_scale is not None, "only support int8 add"
# output_add_scale = add_input.int8_scale
# if self._int8_weight.numel() == 0:
# with torch.no_grad():
# assert ALL_WEIGHT_IS_KRSC
# weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# num_1s = [1] * (self.ndim + 1)
# self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# if self._int8_weight_scale.numel() == 0:
# self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# self._int8_bias = bias_for_infer * int8_out_scale
if
training
:
msg
=
"act don't support backward, only used in inference"
msg
=
"act don't support backward, only used in inference"
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
...
@@ -349,12 +206,12 @@ class SparseConvolution(SparseModule):
...
@@ -349,12 +206,12 @@ class SparseConvolution(SparseModule):
# t = time.time()
# t = time.time()
out_tensor
=
input
.
shadow_copy
()
out_tensor
=
input
.
shadow_copy
()
if
input
.
benchmark
:
if
input
.
benchmark
:
if
self
.
name
is
None
:
if
name
is
None
:
raise
ValueError
(
raise
ValueError
(
"you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
"you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
)
)
if
self
.
name
not
in
input
.
benchmark_record
:
if
name
not
in
input
.
benchmark_record
:
input
.
benchmark_record
[
self
.
name
]
=
{
input
.
benchmark_record
[
name
]
=
{
"type"
:
"SparseConvolution"
,
"type"
:
"SparseConvolution"
,
"indice_gen_time"
:
[],
"indice_gen_time"
:
[],
"time"
:
[],
"time"
:
[],
...
@@ -372,7 +229,7 @@ class SparseConvolution(SparseModule):
...
@@ -372,7 +229,7 @@ class SparseConvolution(SparseModule):
"out_channels"
:
self
.
out_channels
,
"out_channels"
:
self
.
out_channels
,
}
}
}
}
if
self
.
conv1x1
and
not
self
.
enable_int8_test_mode
:
if
self
.
conv1x1
and
not
is_int8
:
# in int8 test mode, we don't implement conv1x1 via mm.
# in int8 test mode, we don't implement conv1x1 via mm.
if
FILTER_HWIO
:
if
FILTER_HWIO
:
features
=
torch
.
mm
(
features
=
torch
.
mm
(
...
@@ -401,8 +258,8 @@ class SparseConvolution(SparseModule):
...
@@ -401,8 +258,8 @@ class SparseConvolution(SparseModule):
assert
algo
==
datas
.
algo
,
msg
assert
algo
==
datas
.
algo
,
msg
# algo = datas.algo
# algo = datas.algo
profile_ctx
=
nullcontext
()
profile_ctx
=
nullcontext
()
if
input
.
_timer
is
not
None
and
self
.
_
sparse_unique_name
:
if
input
.
_timer
is
not
None
and
sparse_unique_name
:
profile_ctx
=
input
.
_timer
.
namespace
(
self
.
_
sparse_unique_name
)
profile_ctx
=
input
.
_timer
.
namespace
(
sparse_unique_name
)
with
profile_ctx
:
with
profile_ctx
:
if
algo
==
ConvAlgo
.
Native
:
if
algo
==
ConvAlgo
.
Native
:
datas
=
input
.
find_indice_pair
(
self
.
indice_key
)
datas
=
input
.
find_indice_pair
(
self
.
indice_key
)
...
@@ -449,7 +306,7 @@ class SparseConvolution(SparseModule):
...
@@ -449,7 +306,7 @@ class SparseConvolution(SparseModule):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
interval
=
time
.
time
()
-
t
out_tensor
.
benchmark_record
[
out_tensor
.
benchmark_record
[
self
.
name
][
"indice_gen_time"
].
append
(
interval
)
name
][
"indice_gen_time"
].
append
(
interval
)
indice_data
=
IndiceData
(
outids
,
indice_data
=
IndiceData
(
outids
,
indices
,
indices
,
...
@@ -567,7 +424,7 @@ class SparseConvolution(SparseModule):
...
@@ -567,7 +424,7 @@ class SparseConvolution(SparseModule):
out_padding
=
self
.
output_padding
,
out_padding
=
self
.
output_padding
,
subm
=
self
.
subm
,
subm
=
self
.
subm
,
transpose
=
self
.
transposed
,
transpose
=
self
.
transposed
,
is_train
=
(
not
self
.
subm
)
or
self
.
training
,
is_train
=
(
not
self
.
subm
)
or
training
,
alloc
=
input
.
thrust_allocator
,
alloc
=
input
.
thrust_allocator
,
timer
=
input
.
_timer
)
timer
=
input
.
_timer
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -583,7 +440,7 @@ class SparseConvolution(SparseModule):
...
@@ -583,7 +440,7 @@ class SparseConvolution(SparseModule):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
interval
=
time
.
time
()
-
t
out_tensor
.
benchmark_record
[
out_tensor
.
benchmark_record
[
self
.
name
][
"indice_gen_time"
].
append
(
interval
)
name
][
"indice_gen_time"
].
append
(
interval
)
outids
=
res
[
0
]
outids
=
res
[
0
]
num_inds_per_loc
=
res
[
1
]
num_inds_per_loc
=
res
[
1
]
pair_fwd
=
res
[
2
]
pair_fwd
=
res
[
2
]
...
@@ -621,16 +478,16 @@ class SparseConvolution(SparseModule):
...
@@ -621,16 +478,16 @@ class SparseConvolution(SparseModule):
num_activate_out
=
outids
.
shape
[
0
]
num_activate_out
=
outids
.
shape
[
0
]
weight_cur
=
weight
weight_cur
=
weight
bias_cur
=
bias_for_infer
bias_cur
=
bias_for_infer
if
self
.
enable_int8_test_mode
:
#
if self.enable_int8_test_mode:
assert
features
.
dtype
==
torch
.
int8
,
"in int8 test mode, feature must be int8"
#
assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
weight_cur
=
self
.
_int8_weight
#
weight_cur = self._int8_weight
bias_cur
=
self
.
_int8_bias
#
bias_cur = self._int8_bias
if
self
.
training
:
if
training
:
out_features
=
Fsp
.
implicit_gemm
(
out_features
=
Fsp
.
implicit_gemm
(
features
,
weight_cur
,
pair_fwd
,
pair_bwd
,
features
,
weight_cur
,
pair_fwd
,
pair_bwd
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
pair_mask_fwd_splits
,
pair_mask_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_bwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
num_activate_out
,
masks
,
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
bias_cur
,
self
.
act_alpha
,
self
.
act_alpha
,
...
@@ -638,20 +495,20 @@ class SparseConvolution(SparseModule):
...
@@ -638,20 +495,20 @@ class SparseConvolution(SparseModule):
self
.
act_type
)
self
.
act_type
)
else
:
else
:
output_dtype
=
None
output_dtype
=
None
if
self
.
_int8_
output_scale
is
None
:
if
output_scale
is
None
:
output_dtype
=
weight
.
dtype
output_dtype
=
weight
.
dtype
out_features
,
_
,
_
=
ops
.
implicit_gemm
(
out_features
,
_
,
_
=
ops
.
implicit_gemm
(
features
,
weight_cur
,
pair_fwd
,
pair_mask_fwd_splits
,
features
,
weight_cur
,
pair_fwd
,
pair_mask_fwd_splits
,
mask_argsort_fwd_splits
,
mask_argsort_fwd_splits
,
num_activate_out
,
masks
,
self
.
training
,
self
.
subm
,
num_activate_out
,
masks
,
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
bias_cur
,
self
.
act_alpha
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_beta
,
self
.
act_type
,
self
.
act_type
,
# TODO do we really need output scale to scale bias in kernel?
# TODO do we really need output scale to scale bias in kernel?
1.0
,
# output_scale
1.0
if
output_scale
is
None
else
output_scale
,
# output_scale
self
.
_int8_weight
_scale
,
# scale
channel
_scale
,
# scale
output_add
=
add_input
.
features
if
add_input
is
not
None
else
None
,
output_add
=
add_input
.
features
if
add_input
is
not
None
else
None
,
output_add_scale
=
output_add_scale
,
output_add_scale
=
output_add_scale
,
output_dtype
=
output_dtype
)
output_dtype
=
output_dtype
)
...
@@ -661,10 +518,10 @@ class SparseConvolution(SparseModule):
...
@@ -661,10 +518,10 @@ class SparseConvolution(SparseModule):
if
input
.
benchmark
:
if
input
.
benchmark
:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
interval
=
time
.
time
()
-
t
interval
=
time
.
time
()
-
t
out_tensor
.
benchmark_record
[
self
.
name
][
"time"
].
append
(
interval
)
out_tensor
.
benchmark_record
[
name
][
"time"
].
append
(
interval
)
out_tensor
.
benchmark_record
[
self
.
name
][
"num_points"
].
append
(
out_tensor
.
benchmark_record
[
name
][
"num_points"
].
append
(
features
.
shape
[
0
])
features
.
shape
[
0
])
out_tensor
.
benchmark_record
[
self
.
name
][
"num_out_points"
].
append
(
out_tensor
.
benchmark_record
[
name
][
"num_out_points"
].
append
(
out_features
.
shape
[
0
])
out_features
.
shape
[
0
])
if
not
self
.
subm
and
not
self
.
inverse
and
self
.
record_voxel_count
:
if
not
self
.
subm
and
not
self
.
inverse
and
self
.
record_voxel_count
:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
...
@@ -675,9 +532,9 @@ class SparseConvolution(SparseModule):
...
@@ -675,9 +532,9 @@ class SparseConvolution(SparseModule):
out_tensor
.
indices
=
outids
out_tensor
.
indices
=
outids
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
indice_dict
=
indice_dict
out_tensor
.
spatial_shape
=
out_spatial_shape
out_tensor
.
spatial_shape
=
out_spatial_shape
if
add_input
is
not
None
and
not
self
.
enable_int8_test_mode
:
if
add_input
is
not
None
and
not
is_int8
:
# in int8, we apply add + act in kernel.
out_tensor
=
out_tensor
.
replace_feature
(
_apply_act
(
out_tensor
.
features
+
add_input
.
features
,
self
.
act_type
,
self
.
act_alpha
,
self
.
act_beta
))
out_tensor
=
out_tensor
.
replace_feature
(
_apply_act
(
out_tensor
.
features
+
add_input
.
features
,
self
.
act_type
,
self
.
act_alpha
,
self
.
act_beta
))
out_tensor
.
int8_scale
=
output_scale
return
out_tensor
return
out_tensor
...
@@ -725,6 +582,629 @@ class SparseConvolution(SparseModule):
...
@@ -725,6 +582,629 @@ class SparseConvolution(SparseModule):
"please check Inverse Convolution in ."
"please check Inverse Convolution in ."
)
)
class
SparseConvolution
(
SparseConvolutionBase
,
SparseModule
):
__constants__
=
[
'stride'
,
'padding'
,
'dilation'
,
'groups'
,
'bias'
,
'subm'
,
'inverse'
,
'transposed'
,
'output_padding'
]
def
__init__
(
self
,
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
,
device
=
None
,
dtype
=
None
):
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
SparseConvolutionBase
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
False
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
)
SparseModule
.
__init__
(
self
,
name
=
name
)
if
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self
.
register_buffer
(
_MAX_NUM_VOXELS_DURING_TRAINING
,
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
device
))
self
.
weight
=
Parameter
(
torch
.
zeros
(
*
self
.
weight_shape
,
**
factory_kwargs
))
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
out_channels
,
**
factory_kwargs
))
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
reset_parameters
()
if
hasattr
(
self
,
"_register_load_state_dict_pre_hook"
):
self
.
_register_load_state_dict_pre_hook
(
self
.
_load_weight_different_layout
)
def
get_max_num_voxels
(
self
)
->
Optional
[
torch
.
Tensor
]:
if
hasattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
):
return
getattr
(
self
,
_MAX_NUM_VOXELS_DURING_TRAINING
)
return
None
# def set_int8_test(self, enable: bool, input_scale: float, output_scale: Optional[float] = None, weight_scale: Optional[torch.Tensor] = None):
# self._int8_input_scale = input_scale
# self._int8_output_scale = output_scale
# if weight_scale is not None:
# self._int8_weight_scale = weight_scale
# self.enable_int8_test_mode = enable
def
_load_weight_different_layout
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
name
=
prefix
+
_MAX_NUM_VOXELS_DURING_TRAINING
if
self
.
record_voxel_count
and
not
self
.
subm
and
not
self
.
inverse
and
name
not
in
state_dict
:
state_dict
[
name
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
)
if
not
SAVED_WEIGHT_LAYOUT
:
return
key
=
prefix
+
"weight"
assert
key
in
state_dict
ndim
=
self
.
ndim
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
,
*
range
(
ndim
),
ndim
+
1
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"RSCK"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
+
1
,
*
range
(
ndim
),
ndim
).
contiguous
()
if
ALL_WEIGHT_IS_KRSC
or
self
.
algo
!=
ConvAlgo
.
Native
:
# in spconv 2.2, we only support KRSC layout.
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
,
*
range
(
ndim
),
ndim
+
1
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"RSCK"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
ndim
+
1
,
*
range
(
ndim
),
ndim
).
contiguous
()
else
:
if
self
.
algo
==
ConvAlgo
.
Native
:
# to RSCK
if
SAVED_WEIGHT_LAYOUT
==
"RSKC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
*
range
(
ndim
),
ndim
+
1
,
ndim
).
contiguous
()
elif
SAVED_WEIGHT_LAYOUT
==
"KRSC"
:
state_dict
[
key
]
=
state_dict
[
key
].
permute
(
*
range
(
1
,
ndim
+
1
),
0
,
ndim
+
1
).
contiguous
()
def
extra_repr
(
self
):
s
=
(
'{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}'
)
if
self
.
padding
!=
(
0
,
)
*
len
(
self
.
padding
):
s
+=
', padding={padding}'
if
self
.
dilation
!=
(
1
,
)
*
len
(
self
.
dilation
):
s
+=
', dilation={dilation}'
if
self
.
output_padding
!=
(
0
,
)
*
len
(
self
.
output_padding
):
s
+=
', output_padding={output_padding}'
if
self
.
groups
!=
1
:
s
+=
', groups={groups}'
if
self
.
bias
is
None
:
s
+=
', bias=False'
if
self
.
algo
is
not
None
:
s
+=
f
', algo=
{
self
.
algo
}
'
return
s
.
format
(
**
self
.
__dict__
)
def
_calculate_fan_in_and_fan_out
(
self
):
receptive_field_size
=
1
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for
s
in
self
.
kernel_size
:
receptive_field_size
*=
s
fan_in
=
self
.
in_channels
*
receptive_field_size
fan_out
=
self
.
out_channels
*
receptive_field_size
return
fan_in
,
fan_out
def
_calculate_correct_fan
(
self
,
mode
):
mode
=
mode
.
lower
()
valid_modes
=
[
'fan_in'
,
'fan_out'
]
if
mode
not
in
valid_modes
:
raise
ValueError
(
"Mode {} not supported, please use one of {}"
.
format
(
mode
,
valid_modes
))
fan_in
,
fan_out
=
self
.
_calculate_fan_in_and_fan_out
()
return
fan_in
if
mode
==
'fan_in'
else
fan_out
def
_custom_kaiming_uniform_
(
self
,
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
r
"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan
=
self
.
_calculate_correct_fan
(
mode
)
gain
=
calculate_gain
(
nonlinearity
,
a
)
std
=
gain
/
math
.
sqrt
(
fan
)
bound
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
with
torch
.
no_grad
():
return
tensor
.
uniform_
(
-
bound
,
bound
)
def
reset_parameters
(
self
):
if
SPCONV_DEBUG_WEIGHT
:
self
.
_custom_kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
0.005
))
else
:
self
.
_custom_kaiming_uniform_
(
self
.
weight
,
a
=
math
.
sqrt
(
5
))
if
self
.
bias
is
not
None
:
fan_in
,
_
=
self
.
_calculate_fan_in_and_fan_out
()
bound
=
1
/
math
.
sqrt
(
fan_in
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight
,
self
.
bias
,
add_input
,
name
=
self
.
name
,
sparse_unique_name
=
self
.
_sparse_unique_name
,
act_type
=
self
.
act_type
,
act_alpha
=
self
.
act_alpha
,
act_beta
=
self
.
act_beta
)
# def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
# channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None):
# assert isinstance(input, SparseConvTensor)
# is_int8 = input.is_quantized and weight.is_quantized
# if is_int8:
# assert output_scale is not None and channel_scale is not None, "int8 must be called in static quantized module"
# assert bias is not None, "currently you must specify a bias"
# assert input.features.shape[
# 1] == self.in_channels, "channel size mismatch"
# features = input.features
# device = features.device
# indices = input.indices
# spatial_shape = input.spatial_shape
# batch_size = input.batch_size
# bias_for_training = bias if self.training else None
# bias_for_infer = bias if not self.training else None
# output_add_scale = 1.0
# if is_int8:
# if add_input is not None:
# output_add_scale = add_input.q_scale()
# # if self.enable_int8_test_mode:
# # assert not self.training, "must in eval mode"
# # assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# # assert bias_for_infer is not None, "conv-bn-relu must be fused"
# # assert self._int8_input_scale is not None
# # if features.dtype != torch.int8:
# # # quantize
# # features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# # output_scale = self._int8_output_scale
# # int8_out_scale = output_scale
# # if int8_out_scale is None:
# # int8_out_scale = 1
# # if add_input is not None:
# # assert add_input.int8_scale is not None, "only support int8 add"
# # output_add_scale = add_input.int8_scale
# # if self._int8_weight.numel() == 0:
# # with torch.no_grad():
# # assert ALL_WEIGHT_IS_KRSC
# # weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# # num_1s = [1] * (self.ndim + 1)
# # self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# # if self._int8_weight_scale.numel() == 0:
# # self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# # self._int8_bias = bias_for_infer * int8_out_scale
# if self.training:
# msg = "act don't support backward, only used in inference"
# assert self.act_type == tv.gemm.Activation.None_, msg
# if not self.subm:
# if self.transposed:
# out_spatial_shape = ops.get_deconv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding)
# else:
# out_spatial_shape = ops.get_conv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation)
# else:
# out_spatial_shape = spatial_shape
# # print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# # input.update_grid(out_spatial_shape)
# # t = time.time()
# out_tensor = input.shadow_copy()
# if input.benchmark:
# if self.name is None:
# raise ValueError(
# "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
# )
# if self.name not in input.benchmark_record:
# input.benchmark_record[self.name] = {
# "type": "SparseConvolution",
# "indice_gen_time": [],
# "time": [],
# "num_points": [],
# "num_out_points": [],
# "params": {
# "kernel_size": self.kernel_size,
# "stride": self.stride,
# "padding": self.padding,
# "dilation": self.dilation,
# "output_padding": self.output_padding,
# "subm": self.subm,
# "transposed": self.transposed,
# "input_channels": self.in_channels,
# "out_channels": self.out_channels,
# }
# }
# if self.conv1x1 and not is_int8:
# # in int8 test mode, we don't implement conv1x1 via mm.
# if FILTER_HWIO:
# features = torch.mm(
# input.features,
# weight.view(self.out_channels, self.in_channels).T)
# else:
# features = torch.mm(
# input.features,
# weight.view(self.in_channels, self.out_channels))
# if bias is not None:
# features += bias
# out_tensor = out_tensor.replace_feature(features)
# # padding may change spatial shape of conv 1x1.
# out_tensor.spatial_shape = out_spatial_shape
# return out_tensor
# indice_dict = input.indice_dict.copy()
# # only support contiguous tensor for now
# if not features.is_contiguous():
# features = features.contiguous()
# algo = self.algo
# if self.indice_key is not None:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
# assert algo == datas.algo, msg
# # algo = datas.algo
# profile_ctx = nullcontext()
# if input._timer is not None and self._sparse_unique_name:
# profile_ctx = input._timer.namespace(self._sparse_unique_name)
# with profile_ctx:
# if algo == ConvAlgo.Native:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, IndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# out_spatial_shape = datas.spatial_shape
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# try:
# outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
# indices, batch_size, spatial_shape, algo,
# self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, self.subm,
# self.transposed)
# except Exception as e:
# msg = "[Exception|native_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# indice_data = IndiceData(outids,
# indices,
# indice_pairs,
# indice_pair_num,
# spatial_shape,
# out_spatial_shape,
# is_subm=self.subm,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# if self.indice_key is not None:
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# indice_pairs_calc = indice_pairs
# if indice_pairs.device != features.device:
# indice_pairs_calc = indice_pairs.to(features.device)
# if self.subm:
# out_features = Fsp.indice_subm_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# if self.inverse:
# out_features = Fsp.indice_inverse_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# out_features = Fsp.indice_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, ImplicitGemmIndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# pair_fwd = datas.pair_bwd
# pair_bwd = datas.pair_fwd
# pair_mask_fwd_splits = datas.pair_mask_bwd_splits
# pair_mask_bwd_splits = datas.pair_mask_fwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
# masks = datas.masks
# out_spatial_shape = datas.spatial_shape
# # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# pair_fwd = datas.pair_fwd
# pair_bwd = datas.pair_bwd
# pair_mask_fwd_splits = datas.pair_mask_fwd_splits
# pair_mask_bwd_splits = datas.pair_mask_bwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
# masks = datas.masks
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# with input._timer.namespace("gen_pairs"):
# # we need to gen bwd indices for regular conv
# # because it may be inversed.
# try:
# res = ops.get_indice_pairs_implicit_gemm(
# indices,
# batch_size,
# spatial_shape,
# algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation,
# out_padding=self.output_padding,
# subm=self.subm,
# transpose=self.transposed,
# is_train=(not self.subm) or self.training,
# alloc=input.thrust_allocator,
# timer=input._timer)
# except Exception as e:
# msg = "[Exception|implicit_gemm_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# outids = res[0]
# num_inds_per_loc = res[1]
# pair_fwd = res[2]
# pair_bwd = res[3]
# pair_mask_fwd_splits = res[4]
# pair_mask_bwd_splits = res[5]
# mask_argsort_fwd_splits = res[6]
# mask_argsort_bwd_splits = res[7]
# masks = res[8]
# if self.indice_key is not None:
# indice_data = ImplicitGemmIndiceData(
# outids,
# indices,
# pair_fwd,
# pair_bwd,
# pair_mask_fwd_splits=pair_mask_fwd_splits,
# pair_mask_bwd_splits=pair_mask_bwd_splits,
# mask_argsort_fwd_splits=mask_argsort_fwd_splits,
# mask_argsort_bwd_splits=mask_argsort_bwd_splits,
# masks=masks,
# is_subm=self.subm,
# spatial_shape=spatial_shape,
# out_spatial_shape=out_spatial_shape,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# num_activate_out = outids.shape[0]
# weight_cur = weight
# bias_cur = bias_for_infer
# # if self.enable_int8_test_mode:
# # assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
# # weight_cur = self._int8_weight
# # bias_cur = self._int8_bias
# if self.training:
# out_features = Fsp.implicit_gemm(
# features, weight_cur, pair_fwd, pair_bwd,
# pair_mask_fwd_splits, pair_mask_bwd_splits,
# mask_argsort_fwd_splits, mask_argsort_bwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# output_dtype = None
# if output_scale is None:
# output_dtype = weight.dtype
# out_features, _, _ = ops.implicit_gemm(
# features, weight_cur, pair_fwd, pair_mask_fwd_splits,
# mask_argsort_fwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type,
# # TODO do we really need output scale to scale bias in kernel?
# 1.0 if output_scale is None else output_scale, # output_scale
# channel_scale, # scale
# output_add=add_input.features if add_input is not None else None,
# output_add_scale=output_add_scale,
# output_dtype=output_dtype)
# if bias_for_training is not None:
# out_features += bias_for_training
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[self.name]["time"].append(interval)
# out_tensor.benchmark_record[self.name]["num_points"].append(
# features.shape[0])
# out_tensor.benchmark_record[self.name]["num_out_points"].append(
# out_features.shape[0])
# if not self.subm and not self.inverse and self.record_voxel_count:
# if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
# ops.maximum_value_int_(
# getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
# outids.shape[0])
# out_tensor = out_tensor.replace_feature(out_features)
# out_tensor.indices = outids
# out_tensor.indice_dict = indice_dict
# out_tensor.spatial_shape = out_spatial_shape
# if add_input is not None and not is_int8:
# # in int8, we apply add + act in kernel.
# out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
# return out_tensor
# def _check_subm_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# assert datas.is_subm, "only support reuse subm indices"
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"subm with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}")
# if self.dilation != datas.dilation:
# raise ValueError(
# f"subm with same indice_key must have same dilation"
# f", expect {datas.dilation}, this layer {self.dilation}")
# if inp.spatial_shape != datas.spatial_shape:
# raise ValueError(
# f"subm with same indice_key must have same spatial structure"
# f", expect {datas.spatial_shape}, input {spatial_shape}")
# if inp.indices.shape[0] != datas.indices.shape[0]:
# raise ValueError(
# f"subm with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
# )
# def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"Inverse with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.spatial_shape != datas.out_spatial_shape:
# raise ValueError(
# f"Inverse with same indice_key must have same spatial structure (spatial shape)"
# f", expect {datas.spatial_shape}, input {spatial_shape}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.indices.shape[0] != datas.out_indices.shape[0]:
# raise ValueError(
# f"Inverse with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
# "please check Inverse Convolution in ."
# )
class
SparseConv1d
(
SparseConvolution
):
class
SparseConv1d
(
SparseConvolution
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
...
@@ -1191,3 +1671,24 @@ class SubMConv4d(SparseConvolution):
...
@@ -1191,3 +1671,24 @@ class SubMConv4d(SparseConvolution):
algo
=
algo
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
fp32_accum
=
fp32_accum
,
name
=
name
)
name
=
name
)
DEFAULT_SPARSE_CONV_TYPES
=
{
SubMConv1d
,
SubMConv2d
,
SubMConv3d
,
SubMConv4d
,
SparseConv1d
,
SparseConv2d
,
SparseConv3d
,
SparseConv4d
,
SparseInverseConv1d
,
SparseInverseConv2d
,
SparseInverseConv3d
,
SparseInverseConv4d
,
SparseConvTranspose1d
,
SparseConvTranspose2d
,
SparseConvTranspose3d
,
SparseConvTranspose4d
,
}
spconv/pytorch/core.py
View file @
b1c57a31
...
@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
...
@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
return
ret
return
ret
# ProxyableClassMeta is used for
TensorRT conversion in future.
# ProxyableClassMeta is used for
torch.fx
class
SparseConvTensor
(
metaclass
=
SpConvTensorMeta
):
class
SparseConvTensor
(
metaclass
=
SpConvTensorMeta
):
def
__init__
(
self
,
def
__init__
(
self
,
features
:
torch
.
Tensor
,
features
:
torch
.
Tensor
,
...
@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
self
.
force_algo
=
force_algo
# for simple int8 torch inference
self
.
int8_scale
:
Optional
[
float
]
=
None
@
property
def
is_quantized
(
self
):
return
self
.
features
.
dtype
==
torch
.
qint8
def
q_scale
(
self
):
if
self
.
is_quantized
:
return
self
.
features
.
q_scale
()
raise
ValueError
(
"sparse tensor must be quantized"
)
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...
@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
x must be NHWC tensor, channel last
x must be NHWC tensor, channel last
"""
"""
x_sp
=
x
.
to_sparse
(
x
.
ndim
-
1
)
x_sp
=
x
.
to_sparse
(
x
.
ndim
-
1
)
spatial_shape
=
list
(
x_sp
.
shape
[
1
:
-
1
]
)
spatial_shape
=
x_sp
.
shape
[
1
:
-
1
]
batch_size
=
x_sp
.
shape
[
0
]
batch_size
=
x_sp
.
shape
[
0
]
indices_th
=
x_sp
.
indices
().
permute
(
1
,
0
).
contiguous
().
int
()
indices_th
=
x_sp
.
indices
().
permute
(
1
,
0
).
contiguous
().
int
()
features_th
=
x_sp
.
values
()
features_th
=
x_sp
.
values
()
...
...
spconv/pytorch/cppcore.py
View file @
b1c57a31
...
@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
...
@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
torch
.
int8
:
tv
.
int8
,
torch
.
int8
:
tv
.
int8
,
torch
.
int16
:
tv
.
int16
,
torch
.
int16
:
tv
.
int16
,
torch
.
uint8
:
tv
.
uint8
,
torch
.
uint8
:
tv
.
uint8
,
torch
.
qint8
:
tv
.
int8
,
}
}
_TORCH_UINT_WORKAROUNDS
=
{
_TORCH_UINT_WORKAROUNDS
=
{
...
@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
...
@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
tv
.
uint64
:
tv
.
int64
tv
.
uint64
:
tv
.
int64
}
}
_TH_QTYPES
=
{
torch
.
qint8
}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TV_DTYPE_TO_TORCH
.
update
({
_TV_DTYPE_TO_TORCH
.
update
({
tv
.
uint32
:
torch
.
int32
,
tv
.
uint32
:
torch
.
int32
,
...
@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
...
@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
})
})
_TV_DTYPE_TO_TORCHQ
=
_TV_DTYPE_TO_TORCH
.
copy
()
_TV_DTYPE_TO_TORCHQ
[
tv
.
int8
]
=
torch
.
qint8
_ALL_INTS
=
{
_ALL_INTS
=
{
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
uint16
tv
.
uint16
...
@@ -105,23 +111,31 @@ def get_arch():
...
@@ -105,23 +111,31 @@ def get_arch():
class
TorchAllocator
(
ExternalAllocator
):
class
TorchAllocator
(
ExternalAllocator
):
def
__init__
(
self
,
gpudevice
:
torch
.
device
)
->
None
:
def
__init__
(
self
,
gpudevice
:
torch
.
device
,
is_quantized
:
bool
=
False
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gpudevice
=
gpudevice
self
.
gpudevice
=
gpudevice
self
.
cpudevice
=
torch
.
device
(
"cpu"
)
self
.
cpudevice
=
torch
.
device
(
"cpu"
)
self
.
allocated
:
Dict
[
Union
[
str
,
int
],
torch
.
Tensor
]
=
{}
self
.
allocated
:
Dict
[
Union
[
str
,
int
],
torch
.
Tensor
]
=
{}
self
.
is_quantized
=
is_quantized
self
.
_tv_dtype_to_torch
=
_TV_DTYPE_TO_TORCH
if
is_quantized
:
self
.
_tv_dtype_to_torch
=
_TV_DTYPE_TO_TORCHQ
def
zeros
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
def
zeros
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
,
scale
:
float
=
1.0
)
->
tv
.
Tensor
:
# TODO free memory by name if its already free by pointer.
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
# provide a name if you want to access it after c++ function exit.
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
zeros
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
if
self
.
is_quantized
:
ten
=
torch
.
_empty_affine_quantized
(
shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
th_dtype
,
device
=
dev
)
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
...
@@ -129,14 +143,17 @@ class TorchAllocator(ExternalAllocator):
...
@@ -129,14 +143,17 @@ class TorchAllocator(ExternalAllocator):
return
ten_tv
return
ten_tv
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
,
scale
:
float
=
1.0
)
->
tv
.
Tensor
:
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
if
self
.
is_quantized
:
ten
=
torch
.
_empty_affine_quantized
(
shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
th_dtype
,
device
=
dev
)
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
...
@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
...
@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
if
self
.
is_quantized
:
assert
th_dtype
not
in
_TH_QTYPES
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
...
@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
...
@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
if
self
.
is_quantized
:
assert
th_dtype
not
in
_TH_QTYPES
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
...
...
spconv/pytorch/quantization/__init__.py
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.backend_cfg
import
(
get_spconv_backend_config
,
get_spconv_prepare_custom_config
,
get_spconv_convert_custom_config
)
from
.fake_q
import
(
get_default_spconv_trt_ptq_qconfig
,
get_default_spconv_trt_qat_qconfig
)
from
.qmapping
import
(
get_spconv_fmod_to_qat_mapping
,
get_spconv_qat_to_static_mapping
)
spconv/pytorch/quantization/backend_cfg.py
0 → 100644
View file @
b1c57a31
from
collections
import
namedtuple
from
typing
import
List
,
Dict
,
Union
,
Type
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.ao.quantization.backend_config
import
(
BackendConfig
,
BackendPatternConfig
,
DTypeConfig
,
ObservationType
,
get_tensorrt_backend_config
)
from
torch.ao.quantization.fx.custom_config
import
(
ConvertCustomConfig
,
FuseCustomConfig
,
PrepareCustomConfig
)
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
import
spconv.pytorch.conv
as
sconvmod
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
import
spconv.pytorch.quantization.quantized
as
snnq
import
spconv.pytorch.quantization.quantized.reference
as
snnqr
from
spconv.pytorch.constants
import
PYTORCH_VERSION
from
spconv.pytorch.quantization.fuse_mapping
import
(
fuse_conv_bn
,
fuse_conv_bn_relu
)
from
spconv.pytorch
import
ToDense
_SpConvMetadataDef
=
namedtuple
(
"_ConvMetadata"
,
[
"root"
,
"bn"
,
"reference"
,
"fused_conv_relu"
,
"fused_conv_bn"
,
"fused_conv_bn_relu"
,
"qat"
,
"relu_qat"
,
"bn_qat"
,
"bn_relu_qat"
])
_SpConvMetadatas
:
List
[
_SpConvMetadataDef
]
=
[]
for
t
in
sconvmod
.
DEFAULT_SPARSE_CONV_TYPES
:
_SpConvMetadatas
.
append
(
_SpConvMetadataDef
(
t
,
nn
.
BatchNorm1d
,
snnqr
.
SpConv
,
snni
.
SpconvReLUNd
,
snni
.
SpconvBnNd
,
snni
.
SpconvBnReLUNd
,
snniqat
.
SparseConv
,
snniqat
.
SparseConvReLU
,
snniqat
.
SparseConvBn
,
snniqat
.
SparseConvBnReLU
))
_SpConvMetadatas
.
append
(
_SpConvMetadataDef
(
sconvmod
.
SparseConvolution
,
nn
.
BatchNorm1d
,
snnqr
.
SpConv
,
snni
.
SpconvReLUNd
,
snni
.
SpconvBnNd
,
snni
.
SpconvBnReLUNd
,
snniqat
.
SparseConv
,
snniqat
.
SparseConvReLU
,
snniqat
.
SparseConvBn
,
snniqat
.
SparseConvBnReLU
))
def
_sequential_wrapper2
(
sequential
):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def
fuser_method
(
is_qat
,
m1
,
m2
):
return
sequential
(
m1
,
m2
)
return
fuser_method
# new cfg remove reverse pattern.
def
_get_spconv_configs
(
dtype_configs
):
"""
Return all configs related to conv modules and ops.
"""
conv_configs
=
[]
observation_type
=
ObservationType
.
OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
if
PYTORCH_VERSION
<=
[
1
,
13
,
1
]:
from
torch.ao.quantization.fuser_method_mappings
import
(
reverse2
,
reverse3
,
reverse_sequential_wrapper2
)
for
convs
in
_SpConvMetadatas
:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
torch
.
nn
.
ReLU
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# 2.3 functional conv + relu configs
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
bn
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse2
(
fuse_conv_bn
))
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
nn
.
ReLU
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (4) conv transpose and its fusion
# 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.bn, convs.transpose))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(reverse2(fuse_conv_bn))
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return
conv_configs
else
:
for
convs
in
_SpConvMetadatas
:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
torch
.
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn
)
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# # (4) conv transpose and its fusion
# # 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.transpose, convs.bn))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_conv_bn)
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return
conv_configs
weighted_op_qint8_dtype_config
=
DTypeConfig
(
input_dtype
=
torch
.
qint8
,
output_dtype
=
torch
.
qint8
,
weight_dtype
=
torch
.
qint8
,
bias_dtype
=
torch
.
float
,
)
conv_dtype_configs
=
[
weighted_op_qint8_dtype_config
,
]
_to_dense_cfg
=
(
BackendPatternConfig
(
ToDense
)
.
set_observation_type
(
ObservationType
.
OUTPUT_SHARE_OBSERVER_WITH_INPUT
))
backend_config
=
get_tensorrt_backend_config
()
\
.
set_backend_pattern_configs
(
_get_spconv_configs
(
conv_dtype_configs
)
+
[
_to_dense_cfg
])
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Tuple
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]]
=
{
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
}
def
get_spconv_backend_config
():
return
backend_config
def
get_spconv_prepare_custom_config
():
cfg
=
PrepareCustomConfig
()
cfg
.
non_traceable_module_classes
=
[
*
sconvmod
.
DEFAULT_SPARSE_CONV_TYPES
]
return
cfg
def
get_spconv_convert_custom_config
():
cfg
=
ConvertCustomConfig
()
cfg
.
set_observed_to_quantized_mapping
(
snni
.
SpconvReLUNd
,
snniq
.
SparseConvReLU
)
# cfg.set_observed_to_quantized_mapping(snni., snniq.SparseConvReLU)
return
cfg
\ No newline at end of file
spconv/pytorch/quantization/fake_q.py
View file @
b1c57a31
from
torch.ao.quantization.fake_quantize
import
FusedMovingAvgObsFakeQuantize
,
fused_wt_fake_quant_range_neg_127_to_127
import
torch
from
torch.ao.quantization.fake_quantize
import
(
FixedQParamsFakeQuantize
,
FusedMovingAvgObsFakeQuantize
,
FakeQuantize
,
default_fused_per_channel_wt_fake_quant
,
default_weight_fake_quant
,
default_per_channel_weight_fake_quant
)
from
torch.ao.quantization.observer
import
(
HistogramObserver
,
MovingAverageMinMaxObserver
,
default_weight_observer
,
default_placeholder_observer
,
default_per_channel_weight_observer
)
from
torch.ao.quantization.qconfig
import
QConfig
,
QConfigAny
,
default_reuse_input_qconfig
from
torch.ao.quantization.qconfig_mapping
import
QConfigMapping
,
_FIXED_QPARAMS_OP_TO_OBSERVER
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
,
Union
,
List
from
torch.ao.quantization
import
get_default_qconfig
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
from
torch.ao.quantization.qconfig
import
QConfig
__all__
=
[
"get_default_spconv_trt_ptq_qconfig"
,
"get_default_spconv_trt_qat_qconfig"
]
from
torch.ao.quantization.observer
import
MovingAverageMinMaxObserver
class
SparseFusedMovingAvgObsFakeQuantize
(
FusedMovingAvgObsFakeQuantize
):
class
SparseFusedMovingAvgObsFakeQuantize
(
FusedMovingAvgObsFakeQuantize
):
def
forward
(
self
,
input
:
SparseConvTensor
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
# add lines to support spconv
if
isinstance
(
input
,
SparseConvTensor
):
x
=
input
.
features
# add lines to support spconv
res_features
=
super
().
forward
(
x
)
x
=
input
.
features
return
input
.
replace_feature
(
res_features
)
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor):
# # add lines to support spconv
# x = input.features
# res_features = super().forward(x)
# return input.replace_feature(res_features)
# else:
# return super().forward(input)
class
SparseHistogramObserver
(
HistogramObserver
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
input
,
SparseConvTensor
):
# add lines to support spconv
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
default_symmetric_spconv_ptq_qconfig
=
QConfig
(
activation
=
SparseHistogramObserver
.
with_args
(
quant_min
=-
128
,
quant_max
=
127
,
dtype
=
torch
.
qint8
,
reduce_range
=
False
,
qscheme
=
torch
.
per_tensor_symmetric
,
eps
=
2
**
-
12
),
weight
=
default_per_channel_weight_observer
)
# default_symmetric_ptq_qconfig = QConfig(
# activation=HistogramObserver.with_args(quant_min=-128,
# quant_max=127,
# dtype=torch.qint8,
# reduce_range=False,
# eps=2 ** -12),
# weight=default_per_channel_weight_observer)
default_symmetric_spconv_qat_qconfig
=
QConfig
(
default_symmetric_spconv_qat_qconfig
=
QConfig
(
activation
=
SparseFusedMovingAvgObsFakeQuantize
.
with_args
(
observer
=
MovingAverageMinMaxObserver
,
activation
=
SparseFusedMovingAvgObsFakeQuantize
.
with_args
(
observer
=
MovingAverageMinMaxObserver
,
...
@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
...
@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
quant_max
=
127
,
quant_max
=
127
,
dtype
=
torch
.
qint8
,
dtype
=
torch
.
qint8
,
reduce_range
=
False
,
reduce_range
=
False
,
qscheme
=
torch
.
per_tensor_symmetric
,
eps
=
2
**
-
12
),
eps
=
2
**
-
12
),
weight
=
fused_wt_fake_quant_range_neg_127_to_127
)
weight
=
default_fused_per_channel_wt_fake_quant
)
def
get_default_spconv_trt_ptq_qconfig
(
backend
,
version
):
return
default_symmetric_spconv_ptq_qconfig
def
get_default_spconv_trt_qat_qconfig
(
backend
,
version
):
return
default_symmetric_spconv_qat_qconfig
def
get_default_spconv_qconfig_mapping
(
is_qat
:
bool
,
backend
:
str
=
"x86"
,
version
:
int
=
0
)
->
QConfigMapping
:
"""
From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend.
"""
# get_default_qconfig(backend, version)
if
is_qat
:
qconfig
=
get_default_spconv_trt_qat_qconfig
(
backend
,
version
)
else
:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
# weight=default_per_channel_weight_observer)
qconfig
=
get_default_spconv_trt_ptq_qconfig
(
backend
,
version
)
default_weight
=
default_weight_fake_quant
if
is_qat
else
default_weight_observer
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
# so we have to modify the weight observer to default_weight_observer or another
# per tensor supported observer.
# see https://github.com/pytorch/pytorch/issues/47535
if
backend
in
(
"fbgemm"
,
"x86"
):
qconfig_transpose
=
QConfig
(
activation
=
qconfig
.
activation
,
weight
=
default_weight
)
else
:
qconfig_transpose
=
qconfig
# currently layernorm only supports float weights
# we have to add this because otherwise there will be a extra quantize-dequantize pair
qconfig_layernorm
=
QConfig
(
activation
=
qconfig
.
activation
,
weight
=
default_placeholder_observer
)
qconfig_mapping
=
QConfigMapping
()
\
.
set_global
(
qconfig
)
\
.
set_object_type
(
"reshape"
,
default_reuse_input_qconfig
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose1d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose2d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose3d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose1d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose2d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose3d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
layer_norm
,
qconfig_layernorm
)
\
.
set_object_type
(
torch
.
nn
.
LayerNorm
,
qconfig_layernorm
)
\
# Use special observers for ops with fixed qparams
fixed_qparams_observer_to_qconfig
:
Dict
[
Any
,
QConfigAny
]
=
{}
for
fixed_qparams_op
,
observer
in
_FIXED_QPARAMS_OP_TO_OBSERVER
.
items
():
if
observer
in
fixed_qparams_observer_to_qconfig
:
fixed_qparams_qconfig
=
fixed_qparams_observer_to_qconfig
[
observer
]
else
:
if
is_qat
:
activation
=
FixedQParamsFakeQuantize
.
with_args
(
observer
=
observer
)
else
:
activation
=
observer
fixed_qparams_qconfig
=
QConfig
(
activation
=
activation
,
weight
=
default_weight
)
fixed_qparams_observer_to_qconfig
[
observer
]
=
fixed_qparams_qconfig
qconfig_mapping
.
set_object_type
(
fixed_qparams_op
,
fixed_qparams_qconfig
)
# QConfig for fused ops for onednn backend
# Separate ops are required to have the same qconfig as fused ops
# TODO: we should be able to configure qconfig for patterns
if
backend
==
'onednn'
:
qconfig_mapping
.
set_object_type
(
torch
.
nn
.
Linear
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
LeakyReLU
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
leaky_relu
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
Tanh
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
tanh
,
qconfig
)
return
qconfig_mapping
spconv/pytorch/quantization/fuse_mapping.py
View file @
b1c57a31
...
@@ -3,9 +3,9 @@ import torch.nn as nn
...
@@ -3,9 +3,9 @@ import torch.nn as nn
import
spconv.pytorch
as
spconv
import
spconv.pytorch
as
spconv
from
.utils
import
fuse_spconv_bn_eval
from
.utils
import
fuse_spconv_bn_eval
from
.
import
intrinsic
as
snni
from
.
import
intrinsic
as
snni
from
.
conv_fused
import
SparseConvBn
,
SparseConvBnReLU
from
.
intrinsic.qat.modules
import
SparseConvBn
,
SparseConvBnReLU
,
SparseConvBnAddReLU
from
spconv.pytorch.conv
import
DEFAULT_SPARSE_CONV_TYPES
def
fuse_conv_bn
(
conv
,
bn
):
def
fuse_conv_bn
(
is_qat
,
conv
,
bn
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
Args:
...
@@ -20,20 +20,12 @@ def fuse_conv_bn(conv, bn):
...
@@ -20,20 +20,12 @@ def fuse_conv_bn(conv, bn):
"""
"""
assert
(
conv
.
training
==
bn
.
training
),
\
assert
(
conv
.
training
==
bn
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map
=
{
fused_module_class_map
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnNd
,
k
:
snni
.
SpconvBnNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnNd
,
}
}
if
conv
.
training
:
if
is_qat
:
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm2d with tracking_running_stats set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm2d with tracking_running_stats set to True'
...
@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
...
@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
else
:
else
:
return
fuse_spconv_bn_eval
(
conv
,
bn
)
return
fuse_spconv_bn_eval
(
conv
,
bn
)
def
fuse_conv_bn_relu
(
conv
,
bn
,
relu
):
def
fuse_conv_bn_relu
(
is_qat
,
conv
,
bn
,
relu
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
Args:
...
@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
assert
(
conv
.
training
==
bn
.
training
==
relu
.
training
),
\
assert
(
conv
.
training
==
bn
.
training
==
relu
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
"Conv and BN both must be in the same mode (train or eval)."
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
if
conv
.
training
:
if
is_qat
:
map_to_fused_module_train
=
{
map_to_fused_module_train
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnReLUNd
,
k
:
snni
.
SpconvBnReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnReLUNd
,
}
}
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
...
@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
else
:
else
:
map_to_fused_module_eval
=
{
map_to_fused_module_eval
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvReLUNd
,
k
:
snni
.
SpconvReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvReLUNd
,
}
}
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
if
fused_module
is
not
None
:
if
fused_module
is
not
None
:
...
@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
else
:
else
:
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
:
Dict
[
Tuple
,
Union
[
nn
.
Sequential
,
Callable
]]
=
{
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
}
# }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
# Default map for swapping float module to qat modules
# Default map for swapping float module to qat modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS
:
Dict
[
Callable
,
Any
]
=
{
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni
.
SpconvBnNd
:
SparseConvBn
,
snni
.
SpconvBnReLUNd
:
SparseConvBnReLU
,
}
spconv/pytorch/quantization/intrinsic/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SpconvBnNd
,
SpconvBnReLUNd
,
SpconvBnAddReLUNd
,
SpconvReLUNd
spconv/pytorch/quantization/intrinsic.py
→
spconv/pytorch/quantization/intrinsic
/modules
.py
View file @
b1c57a31
...
@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
...
@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
import
torch.ao.nn.intrinsic
as
nni
import
torch.ao.nn.intrinsic
as
nni
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.modules
import
is_spconv_module
from
spconv.pytorch.core
import
SparseConvTensor
class
_FusedSparseModule
(
nni
.
_FusedModule
):
def
forward
(
self
,
input
):
for
k
,
module
in
self
.
_modules
.
items
():
if
is_spconv_module
(
module
):
# use SpConvTensor as input
if
isinstance
(
input
,
list
):
input
=
module
(
input
)
else
:
# assert isinstance(input, spconv.SparseConvTensor)
# self._sparity_dict[k] = input.sparity
input
=
module
(
input
)
else
:
if
isinstance
(
input
,
SparseConvTensor
):
if
input
.
indices
.
shape
[
0
]
!=
0
:
input
=
input
.
replace_feature
(
module
(
input
.
features
))
else
:
input
=
module
(
input
)
return
input
class
SpconvReLUNd
(
nni
.
_FusedModule
):
class
SpconvReLUNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv3d and ReLU modules.
r
"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
relu
):
def
__init__
(
self
,
conv
,
relu
):
...
@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
...
@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
type
(
conv
),
type
(
relu
))
type
(
conv
),
type
(
relu
))
super
().
__init__
(
conv
,
relu
)
super
().
__init__
(
conv
,
relu
)
class
SpconvBnNd
(
nni
.
_FusedModule
):
class
SpconvBnNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
r
"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
):
def
__init__
(
self
,
conv
,
bn
):
...
@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
...
@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
type
(
conv
),
type
(
bn
))
type
(
conv
),
type
(
bn
))
super
().
__init__
(
conv
,
bn
)
super
().
__init__
(
conv
,
bn
)
class
SpconvBnReLUNd
(
_FusedSparseModule
):
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
,
relu
):
assert
isinstance
(
conv
,
SparseConvolution
)
and
isinstance
(
bn
,
BatchNorm1d
)
and
\
isinstance
(
relu
,
ReLU
),
'Incorrect types for input modules{}{}{}'
\
.
format
(
type
(
conv
),
type
(
bn
),
type
(
relu
))
super
().
__init__
(
conv
,
bn
,
relu
)
class
SpconvBnReLUNd
(
nni
.
_FusedModule
):
class
SpconvBn
Add
ReLUNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
,
relu
):
def
__init__
(
self
,
conv
,
bn
,
relu
):
...
...
spconv/pytorch/quantization/intrinsic/qat/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SparseConvBn
,
SparseConvBnAddReLU
,
SparseConvBnReLU
,
SparseConv
,
SparseConvReLU
\ No newline at end of file
spconv/pytorch/quantization/
conv_fused
.py
→
spconv/pytorch/quantization/
intrinsic/qat/modules
.py
View file @
b1c57a31
...
@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
...
@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
import
torch.ao.nn.qat
as
nnqat
import
torch.ao.nn.qat
as
nnqat
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn
import
init
from
torch.nn
import
init
from
torch.nn.utils
import
fuse_conv_bn_weights
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
typing
import
TypeVar
from
typing
import
TypeVar
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.conv
import
SparseConvolution
...
@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
...
@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic
as
snni
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
class
_SparseConv
(
SparseConvolution
,
nni
.
_FusedModule
):
_FLOAT_MODULE
=
MOD
_FLOAT_CONV_MODULE
=
SparseConvolution
def
__init__
(
self
,
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
,
qconfig
=
None
,
device
=
None
,
dtype
=
None
)
->
None
:
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
SparseConvolution
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
False
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
name
=
name
,
**
factory_kwargs
)
assert
qconfig
,
'qconfig must be provided for QAT module'
self
.
qconfig
=
qconfig
self
.
weight_fake_quant
=
qconfig
.
weight
(
factory_kwargs
=
factory_kwargs
)
def
forward
(
self
,
input
):
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight_fake_quant
(
self
.
weight
),
self
.
bias
)
@
staticmethod
def
from_float
(
cls
,
mod
):
r
"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert
type
(
mod
)
==
cls
.
_FLOAT_MODULE
,
(
"qat."
+
cls
.
__name__
+
".from_float only works for "
+
cls
.
_FLOAT_MODULE
.
__name__
# type: ignore[attr-defined]
)
assert
hasattr
(
mod
,
'qconfig'
),
'Input float module must have qconfig defined'
assert
mod
.
qconfig
,
'Input float module must have a valid qconfig'
if
issubclass
(
type
(
mod
),
nni
.
_FusedModule
):
mod
=
mod
[
0
]
# type: ignore[index]
conv
:
SparseConvolution
=
mod
qconfig
=
mod
.
qconfig
qat_conv
=
cls
(
conv
.
ndim
,
conv
.
in_channels
,
conv
.
out_channels
,
conv
.
kernel_size
,
conv
.
stride
,
conv
.
padding
,
conv
.
dilation
,
conv
.
groups
,
conv
.
bias
is
not
None
,
subm
=
conv
.
subm
,
output_padding
=
conv
.
output_padding
,
transposed
=
conv
.
transposed
,
inverse
=
conv
.
inverse
,
indice_key
=
conv
.
indice_key
,
algo
=
conv
.
algo
,
fp32_accum
=
conv
.
fp32_accum
,
record_voxel_count
=
conv
.
record_voxel_count
,
act_type
=
conv
.
act_type
,
act_alpha
=
conv
.
act_alpha
,
act_beta
=
conv
.
act_beta
,
name
=
conv
.
name
,
qconfig
=
qconfig
)
qat_conv
.
weight
=
mod
.
weight
qat_conv
.
bias
=
mod
.
bias
return
qat_conv
def
to_float
(
self
):
""" This works for both single qat conv, and the qat conv - relu modules
to convert the qat module to a floating point module
"""
cls
=
type
(
self
)
conv
=
cls
.
_FLOAT_CONV_MODULE
(
# type: ignore[attr-defined]
self
.
ndim
,
self
.
in_channels
,
self
.
out_channels
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
,
self
.
bias
is
not
None
,
subm
=
self
.
subm
,
output_padding
=
self
.
output_padding
,
transposed
=
self
.
transposed
,
inverse
=
self
.
inverse
,
indice_key
=
self
.
indice_key
,
algo
=
self
.
algo
,
fp32_accum
=
self
.
fp32_accum
,
record_voxel_count
=
self
.
record_voxel_count
,
act_type
=
self
.
act_type
,
act_alpha
=
self
.
act_alpha
,
act_beta
=
self
.
act_beta
,
name
=
self
.
name
)
conv
.
weight
=
torch
.
nn
.
Parameter
(
self
.
weight
.
detach
())
if
self
.
bias
is
not
None
:
conv
.
bias
=
torch
.
nn
.
Parameter
(
self
.
bias
.
detach
())
# conv relu
if
issubclass
(
cls
,
nni
.
_FusedModule
):
modules
=
[
conv
]
assert
hasattr
(
cls
,
"_FLOAT_RELU_MODULE"
)
relu
=
cls
.
_FLOAT_RELU_MODULE
()
# type: ignore[attr-defined]
modules
.
append
(
relu
)
fused
=
cls
.
_FLOAT_MODULE
(
*
modules
)
# type: ignore[arg-type, attr-defined, operator]
fused
.
train
(
self
.
training
)
return
fused
else
:
return
conv
class
SparseConv
(
_SparseConv
,
SparseConvolution
):
r
"""
A Conv1d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv1d`
Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE
=
SparseConvolution
_FLOAT_CONV_MODULE
=
SparseConvolution
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
().
from_float
(
cls
,
mod
)
class
SparseConvReLU
(
SparseConv
,
nni
.
_FusedModule
):
r
"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for weight for
quantization aware training.
We combined the interface of :class:`~torch.nn.Conv2d` and
:class:`~torch.nn.BatchNorm2d`.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE
=
snni
.
SpconvReLUNd
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
None
_FLOAT_RELU_MODULE
=
nn
.
ReLU
def
forward
(
self
,
input
):
x
=
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight_fake_quant
(
self
.
weight
),
self
.
bias
)
return
x
.
replace_feature
(
F
.
relu
(
x
.
features
))
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvReLU
,
cls
).
from_float
(
mod
)
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
_version
=
2
_version
=
2
...
@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
...
@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias
=
torch
.
zeros_like
(
self
.
bias
,
dtype
=
input
.
features
.
dtype
)
zero_bias
=
torch
.
zeros_like
(
self
.
bias
,
dtype
=
input
.
features
.
dtype
)
else
:
else
:
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
conv_spt
=
self
.
_conv_forward
(
input
,
scaled_weight
,
zero_bias
)
conv_spt
=
self
.
_conv_forward
(
self
.
training
,
input
,
scaled_weight
,
zero_bias
)
conv
=
conv_spt
.
features
conv
=
conv_spt
.
features
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
...
@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
if
cls
.
_FLOAT_BN_MODULE
:
# type: ignore[attr-defined]
if
cls
.
_FLOAT_BN_MODULE
:
# type: ignore[attr-defined]
# fuse bn into conv
# fuse bn into conv
conv
.
weight
,
conv
.
bias
=
fuse_conv_bn_weights
(
conv
.
weight
,
conv
.
bias
=
fuse_
sp
conv_bn_weights
(
conv
.
weight
,
conv
.
weight
,
conv
.
bias
,
conv
.
bias
,
self
.
bn
.
running_mean
,
self
.
bn
.
running_mean
,
...
@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
...
@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
@
classmethod
@
classmethod
def
from_float
(
cls
,
mod
):
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvBnReLU
,
cls
).
from_float
(
mod
)
return
super
(
SparseConvBnReLU
,
cls
).
from_float
(
mod
)
class
SparseConvBnAddReLU
(
_SparseConvBn
):
r
"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE
=
snni
.
SpconvBnReLUNd
# type: ignore[assignment]
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
nn
.
BatchNorm1d
_FLOAT_RELU_MODULE
=
nn
.
ReLU
# type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE
=
snni
.
SpconvReLUNd
def
forward
(
self
,
input
,
add_input
):
x
=
_SparseConvBn
.
_forward
(
self
,
input
,
add_input
)
return
x
.
replace_feature
(
F
.
relu
(
x
.
features
))
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvBnAddReLU
,
cls
).
from_float
(
mod
)
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.conv_relu
import
*
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment