Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
b1c57a31
Commit
b1c57a31
authored
Jan 03, 2023
by
yan.yan
Browse files
still working on int8
parent
aa26c99e
Changes
25
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1871 additions
and
343 deletions
+1871
-343
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+226
-19
spconv/build.py
spconv/build.py
+1
-1
spconv/core.py
spconv/core.py
+32
-2
spconv/core_cc/csrc/sparse/alloc.pyi
spconv/core_cc/csrc/sparse/alloc.pyi
+4
-2
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+2
-1
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+2
-2
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+2
-2
spconv/csrc/sparse/maxpool.py
spconv/csrc/sparse/maxpool.py
+3
-3
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+720
-219
spconv/pytorch/core.py
spconv/pytorch/core.py
+11
-4
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+30
-9
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+21
-0
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+360
-0
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+133
-10
spconv/pytorch/quantization/fuse_mapping.py
spconv/pytorch/quantization/fuse_mapping.py
+33
-60
spconv/pytorch/quantization/intrinsic/__init__.py
spconv/pytorch/quantization/intrinsic/__init__.py
+15
-0
spconv/pytorch/quantization/intrinsic/modules.py
spconv/pytorch/quantization/intrinsic/modules.py
+30
-3
spconv/pytorch/quantization/intrinsic/qat/__init__.py
spconv/pytorch/quantization/intrinsic/qat/__init__.py
+15
-0
spconv/pytorch/quantization/intrinsic/qat/modules.py
spconv/pytorch/quantization/intrinsic/qat/modules.py
+216
-6
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
+15
-0
No files found.
example/mnist/mnist_qat.py
View file @
b1c57a31
...
@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
...
@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
from
torch.optim.lr_scheduler
import
StepLR
from
torch.optim.lr_scheduler
import
StepLR
import
contextlib
import
contextlib
import
torch.cuda.amp
import
torch.cuda.amp
import
torch.ao.quantization
from
torch.ao.quantization
import
QuantStub
,
DeQuantStub
import
torch.ao.quantization.quantize_fx
as
qfx
from
spconv.pytorch.quantization.fake_q
import
get_default_spconv_qconfig_mapping
import
spconv.pytorch.quantization
as
spconvq
from
spconv.pytorch.quantization
import
get_default_spconv_trt_ptq_qconfig
from
torch.ao.quantization
import
get_default_qconfig_mapping
from
spconv.pytorch.quantization.backend_cfg
import
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from
torch.ao.quantization.fx._lower_to_native_backend
import
STATIC_LOWER_FUSED_MODULE_MAP
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
identity_ctx
():
def
identity_ctx
():
yield
yield
class
SubMConvBNReLU
(
spconv
.
SparseSequential
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
1
,
groups
=
1
):
padding
=
(
kernel_size
-
1
)
//
2
super
(
SubMConvBNReLU
,
self
).
__init__
(
spconv
.
SubMConv2d
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
padding
,
groups
=
groups
,
bias
=
False
),
nn
.
BatchNorm1d
(
out_planes
,
momentum
=
0.1
),
# Replace with ReLU
nn
.
ReLU
(
inplace
=
False
)
)
class
SparseConvBNReLU
(
spconv
.
SparseSequential
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
1
,
groups
=
1
):
padding
=
(
kernel_size
-
1
)
//
2
super
(
SparseConvBNReLU
,
self
).
__init__
(
spconv
.
SparseConv2d
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
padding
,
groups
=
groups
,
bias
=
False
),
nn
.
BatchNorm1d
(
out_planes
,
momentum
=
0.1
),
# Replace with ReLU
nn
.
ReLU
(
inplace
=
False
)
)
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
super
(
Net
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
self
.
net
=
spconv
.
SparseSequential
(
nn
.
BatchNorm1d
(
1
),
SubMConvBNReLU
(
1
,
32
,
3
),
spconv
.
SubMConv2d
(
1
,
32
,
3
,
1
),
SubMConvBNReLU
(
32
,
64
,
3
),
nn
.
ReLU
(),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
spconv
.
SubMConv2d
(
32
,
64
,
3
,
1
),
spconv
.
ToDense
(),
nn
.
ReLU
(),
)
spconv
.
SparseConv2d
(
64
,
64
,
2
,
2
),
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
x_sp
:
spconv
.
SparseConvTensor
):
# def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor
# x = self.quant(x)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
# x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
dropout1
(
x
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
# x = self.dequant(x)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
class
NetV2
(
nn
.
Module
):
def
__init__
(
self
):
super
(
NetV2
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
spconv
.
ToDense
(),
spconv
.
ToDense
(),
)
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
batch_size
:
int
):
# x: [N, 28, 28, 1], must be NHWC tensor
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
x
.
reshape
(
-
1
,
28
,
28
,
1
))
x
=
self
.
quant
(
features
)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp
=
spconv
.
SparseConvTensor
(
features
,
indices
,
[
28
,
28
],
batch_size
)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
torch
.
flatten
(
x
,
1
)
...
@@ -58,10 +120,93 @@ class Net(nn.Module):
...
@@ -58,10 +120,93 @@ class Net(nn.Module):
x
=
F
.
relu
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
class
NetPTQ
(
nn
.
Module
):
"""pytorch currently don't support cuda int8 inference, so
we only use sparse ops here.
"""
def
__init__
(
self
):
super
(
NetPTQ
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 14x14
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 7x7
SparseConvBNReLU
(
64
,
64
,
3
,
2
,
1
),
# 4x4
spconv
.
SparseConv2d
(
64
,
10
,
4
,
4
),
spconv
.
ToDense
(),
)
# self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10)
# self.dropout1 = nn.Dropout2d(0.25)
# self.dropout2 = nn.Dropout2d(0.5)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
batch_size
:
int
):
# x: [N, 28, 28, 1], must be NHWC tensor
features
=
self
.
quant
(
features
)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp
=
spconv
.
SparseConvTensor
(
features
,
indices
,
[
28
,
28
],
batch_size
)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x_sp
=
self
.
net
(
x_sp
)
# print(x_sp.shape)
x
=
x_sp
x
=
torch
.
flatten
(
x
,
1
)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
return
output
class
NetDense
(
nn
.
Module
):
def
__init__
(
self
):
super
(
NetDense
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
3
,
1
)
self
.
conv2
=
nn
.
Conv2d
(
32
,
64
,
3
,
1
)
self
.
dropout1
=
nn
.
Dropout
(
0.25
)
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
self
.
fc1
=
nn
.
Linear
(
9216
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
def
forward
(
self
,
x
):
x
=
self
.
quant
(
x
)
x
=
self
.
conv1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
F
.
relu
(
x
)
x
=
F
.
max_pool2d
(
x
,
2
)
x
=
self
.
dropout1
(
x
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
def
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
):
model
.
train
()
model
.
train
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
scaler
=
torch
.
cuda
.
amp
.
grad_scaler
.
GradScaler
()
...
@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
...
@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
with
amp_ctx
:
with
amp_ctx
:
if
args
.
sparse
:
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
data
.
reshape
(
-
1
,
28
,
28
,
1
))
# output = model(data_sp)
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
else
:
output
=
model
(
data
)
output
=
model
(
data
)
loss
=
F
.
nll_loss
(
output
,
target
)
loss
=
F
.
nll_loss
(
output
,
target
)
scale
=
1.0
scale
=
1.0
if
args
.
fp16
:
if
args
.
fp16
:
...
@@ -114,7 +265,11 @@ def test(args, model, device, test_loader):
...
@@ -114,7 +265,11 @@ def test(args, model, device, test_loader):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
with
amp_ctx
:
with
amp_ctx
:
if
args
.
sparse
:
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
data
.
reshape
(
-
1
,
28
,
28
,
1
))
# output = model(data_sp)
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
else
:
output
=
model
(
data
)
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
output
,
target
,
reduction
=
'sum'
).
item
()
# sum up batch loss
...
@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
...
@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
100.
*
correct
/
len
(
test_loader
.
dataset
)))
100.
*
correct
/
len
(
test_loader
.
dataset
)))
def
calibrate
(
args
,
model
:
torch
.
nn
.
Module
,
data_loader
,
device
):
model
.
eval
()
with
torch
.
no_grad
():
for
image
,
target
in
data_loader
:
image
=
image
.
to
(
device
)
if
args
.
sparse
:
data_sp
=
spconv
.
SparseConvTensor
.
from_dense
(
image
.
reshape
(
-
1
,
28
,
28
,
1
))
output
=
model
(
data_sp
.
features
,
data_sp
.
indices
,
data_sp
.
batch_size
)
# output = model(data_sp)
else
:
output
=
model
(
image
)
def
main
():
def
main
():
# Training settings
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
@@ -146,7 +314,7 @@ def main():
...
@@ -146,7 +314,7 @@ def main():
help
=
'input batch size for testing (default: 1000)'
)
help
=
'input batch size for testing (default: 1000)'
)
parser
.
add_argument
(
'--epochs'
,
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
type
=
int
,
default
=
1
4
,
default
=
1
,
metavar
=
'N'
,
metavar
=
'N'
,
help
=
'number of epochs to train (default: 14)'
)
help
=
'number of epochs to train (default: 14)'
)
parser
.
add_argument
(
'--lr'
,
parser
.
add_argument
(
'--lr'
,
...
@@ -168,6 +336,10 @@ def main():
...
@@ -168,6 +336,10 @@ def main():
default
=
1
,
default
=
1
,
metavar
=
'S'
,
metavar
=
'S'
,
help
=
'random seed (default: 1)'
)
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--sparse'
,
action
=
'store_true'
,
default
=
True
,
help
=
'use sparse conv network instead of dense'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--log-interval'
,
'--log-interval'
,
type
=
int
,
type
=
int
,
...
@@ -190,8 +362,14 @@ def main():
...
@@ -190,8 +362,14 @@ def main():
torch
.
manual_seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
model
=
NetPTQ
().
to
(
device
)
else
:
model
=
NetDense
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
MNIST
(
datasets
.
MNIST
(
'../data'
,
'../data'
,
...
@@ -218,17 +396,46 @@ def main():
...
@@ -218,17 +396,46 @@ def main():
shuffle
=
True
,
shuffle
=
True
,
**
kwargs
)
**
kwargs
)
model
=
Net
().
to
(
device
)
optimizer
=
optim
.
Adadelta
(
model
.
parameters
(),
lr
=
args
.
lr
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
scheduler
=
StepLR
(
optimizer
,
step_size
=
1
,
gamma
=
args
.
gamma
)
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
test
(
args
,
model
,
device
,
test_loader
)
test
(
args
,
model
,
device
,
test_loader
)
scheduler
.
step
()
scheduler
.
step
()
# if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt")
model
.
eval
()
STATIC_LOWER_FUSED_MODULE_MAP
.
update
(
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
)
if
not
args
.
sparse
:
model
=
model
.
cpu
()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
qconfig_mapping
=
get_default_spconv_qconfig_mapping
(
False
)
prepare_cfg
=
spconvq
.
get_spconv_prepare_custom_config
()
backend_cfg
=
spconvq
.
get_spconv_backend_config
()
convert_cfg
=
spconvq
.
get_spconv_convert_custom_config
()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model
=
qfx
.
prepare_fx
(
model
,
qconfig_mapping
,
(),
backend_config
=
backend_cfg
,
prepare_custom_config
=
prepare_cfg
)
# prepared_model.print_readable()
print
([
type
(
m
)
for
m
in
prepared_model
.
modules
()])
print
(
prepared_model
)
# calibrate: run model with some inputs
calibrate
(
args
,
prepared_model
,
test_loader
,
qdevice
)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model
=
qfx
.
convert_to_reference_fx
(
prepared_model
,
convert_cfg
,
qconfig_mapping
=
qconfig_mapping
,
backend_config
=
backend_cfg
)
print
([
type
(
m
)
for
m
in
converted_model
.
modules
()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print
(
converted_model
)
if
args
.
save_model
:
test
(
args
,
converted_model
,
qdevice
,
test_loader
)
torch
.
save
(
model
.
state_dict
(),
"mnist_cnn.pt"
)
breakpoint
(
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
spconv/build.py
View file @
b1c57a31
...
@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
...
@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
if
project_is_installed
(
PACKAGE_NAME
)
and
project_is_editable
(
if
project_is_installed
(
PACKAGE_NAME
)
and
project_is_editable
(
PACKAGE_NAME
)
and
not
DISABLE_JIT
:
PACKAGE_NAME
)
and
not
DISABLE_JIT
and
False
:
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
,
SHUFFLE_AMPERE_PARAMS
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
,
SHUFFLE_AMPERE_PARAMS
from
spconv.core
import
IMPLGEMM_SIMT_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
,
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_AMPERE_PARAMS
from
spconv.core
import
IMPLGEMM_SIMT_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
,
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_AMPERE_PARAMS
...
...
spconv/core.py
View file @
b1c57a31
...
@@ -699,6 +699,22 @@ IMPLGEMM_AMPERE_PARAMS = [
...
@@ -699,6 +699,22 @@ IMPLGEMM_AMPERE_PARAMS = [
is_nvrtc
=
True
,
is_nvrtc
=
True
,
int8_inference
=
True
),
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
...
@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector
=
1
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
is_nvrtc
=
True
,
int8_inference
=
True
),
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
ConvIterAlgo
.
Optimized
,
...
...
spconv/core_cc/csrc/sparse/alloc.pyi
View file @
b1c57a31
...
@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
...
@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import Tensor
class ExternalAllocator:
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False
, scale: float = 1.0
) -> Tensor:
"""
"""
Args:
Args:
name:
name:
...
@@ -11,9 +11,10 @@ class ExternalAllocator:
...
@@ -11,9 +11,10 @@ class ExternalAllocator:
device:
device:
stream:
stream:
is_temp_memory:
is_temp_memory:
scale:
"""
"""
...
...
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False
, scale: float = 1.0
) -> Tensor:
"""
"""
Args:
Args:
name:
name:
...
@@ -22,6 +23,7 @@ class ExternalAllocator:
...
@@ -22,6 +23,7 @@ class ExternalAllocator:
device:
device:
stream:
stream:
is_temp_memory:
is_temp_memory:
scale:
"""
"""
...
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
b1c57a31
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
...
@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0
, output_dtype: int = -1
) -> Tuple[int, Any]:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -91,6 +91,7 @@ class ConvGemmOps:
...
@@ -91,6 +91,7 @@ class ConvGemmOps:
scale:
scale:
output_add:
output_add:
output_add_scale:
output_add_scale:
output_dtype:
"""
"""
...
...
@staticmethod
@staticmethod
...
...
spconv/csrc/sparse/alloc.py
View file @
b1c57a31
...
@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
...
@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
...
@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
spconv/csrc/sparse/convops.py
View file @
b1c57a31
...
@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
if (is_subm){{
if (is_subm){{
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int
, false /*is_temp*/, output_scale
);
}}else{{
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int
, false /*is_temp*/, output_scale
);
}}
}}
// auto start_ev = tv::CUDAEvent();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// start_ev.record(stream_int);
...
...
spconv/csrc/sparse/maxpool.py
View file @
b1c57a31
...
@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto nhot = out_inds.dim(0);
auto nhot = out_inds.dim(0);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
...
@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t
, int8_t
>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
int num_blocks_X = std::get<0>(launchdims);
...
...
spconv/pytorch/conv.py
View file @
b1c57a31
This diff is collapsed.
Click to expand it.
spconv/pytorch/core.py
View file @
b1c57a31
...
@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
...
@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
return
ret
return
ret
# ProxyableClassMeta is used for
TensorRT conversion in future.
# ProxyableClassMeta is used for
torch.fx
class
SparseConvTensor
(
metaclass
=
SpConvTensorMeta
):
class
SparseConvTensor
(
metaclass
=
SpConvTensorMeta
):
def
__init__
(
self
,
def
__init__
(
self
,
features
:
torch
.
Tensor
,
features
:
torch
.
Tensor
,
...
@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
self
.
force_algo
=
force_algo
# for simple int8 torch inference
self
.
int8_scale
:
Optional
[
float
]
=
None
@
property
def
is_quantized
(
self
):
return
self
.
features
.
dtype
==
torch
.
qint8
def
q_scale
(
self
):
if
self
.
is_quantized
:
return
self
.
features
.
q_scale
()
raise
ValueError
(
"sparse tensor must be quantized"
)
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
def
replace_feature
(
self
,
feature
:
torch
.
Tensor
):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...
@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
x must be NHWC tensor, channel last
x must be NHWC tensor, channel last
"""
"""
x_sp
=
x
.
to_sparse
(
x
.
ndim
-
1
)
x_sp
=
x
.
to_sparse
(
x
.
ndim
-
1
)
spatial_shape
=
list
(
x_sp
.
shape
[
1
:
-
1
]
)
spatial_shape
=
x_sp
.
shape
[
1
:
-
1
]
batch_size
=
x_sp
.
shape
[
0
]
batch_size
=
x_sp
.
shape
[
0
]
indices_th
=
x_sp
.
indices
().
permute
(
1
,
0
).
contiguous
().
int
()
indices_th
=
x_sp
.
indices
().
permute
(
1
,
0
).
contiguous
().
int
()
features_th
=
x_sp
.
values
()
features_th
=
x_sp
.
values
()
...
...
spconv/pytorch/cppcore.py
View file @
b1c57a31
...
@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
...
@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
torch
.
int8
:
tv
.
int8
,
torch
.
int8
:
tv
.
int8
,
torch
.
int16
:
tv
.
int16
,
torch
.
int16
:
tv
.
int16
,
torch
.
uint8
:
tv
.
uint8
,
torch
.
uint8
:
tv
.
uint8
,
torch
.
qint8
:
tv
.
int8
,
}
}
_TORCH_UINT_WORKAROUNDS
=
{
_TORCH_UINT_WORKAROUNDS
=
{
...
@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
...
@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
tv
.
uint64
:
tv
.
int64
tv
.
uint64
:
tv
.
int64
}
}
_TH_QTYPES
=
{
torch
.
qint8
}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TV_DTYPE_TO_TORCH
=
{
v
:
k
for
k
,
v
in
_TORCH_DTYPE_TO_TV
.
items
()}
_TV_DTYPE_TO_TORCH
.
update
({
_TV_DTYPE_TO_TORCH
.
update
({
tv
.
uint32
:
torch
.
int32
,
tv
.
uint32
:
torch
.
int32
,
...
@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
...
@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
})
})
_TV_DTYPE_TO_TORCHQ
=
_TV_DTYPE_TO_TORCH
.
copy
()
_TV_DTYPE_TO_TORCHQ
[
tv
.
int8
]
=
torch
.
qint8
_ALL_INTS
=
{
_ALL_INTS
=
{
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
int32
,
tv
.
int16
,
tv
.
int8
,
tv
.
int64
,
tv
.
uint64
,
tv
.
uint8
,
tv
.
uint32
,
tv
.
uint16
tv
.
uint16
...
@@ -105,23 +111,31 @@ def get_arch():
...
@@ -105,23 +111,31 @@ def get_arch():
class
TorchAllocator
(
ExternalAllocator
):
class
TorchAllocator
(
ExternalAllocator
):
def
__init__
(
self
,
gpudevice
:
torch
.
device
)
->
None
:
def
__init__
(
self
,
gpudevice
:
torch
.
device
,
is_quantized
:
bool
=
False
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gpudevice
=
gpudevice
self
.
gpudevice
=
gpudevice
self
.
cpudevice
=
torch
.
device
(
"cpu"
)
self
.
cpudevice
=
torch
.
device
(
"cpu"
)
self
.
allocated
:
Dict
[
Union
[
str
,
int
],
torch
.
Tensor
]
=
{}
self
.
allocated
:
Dict
[
Union
[
str
,
int
],
torch
.
Tensor
]
=
{}
self
.
is_quantized
=
is_quantized
self
.
_tv_dtype_to_torch
=
_TV_DTYPE_TO_TORCH
if
is_quantized
:
self
.
_tv_dtype_to_torch
=
_TV_DTYPE_TO_TORCHQ
def
zeros
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
def
zeros
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
,
scale
:
float
=
1.0
)
->
tv
.
Tensor
:
# TODO free memory by name if its already free by pointer.
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
# provide a name if you want to access it after c++ function exit.
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
ten
=
torch
.
zeros
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
if
self
.
is_quantized
:
ten
=
torch
.
_empty_affine_quantized
(
shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
th_dtype
,
device
=
dev
)
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
...
@@ -129,13 +143,16 @@ class TorchAllocator(ExternalAllocator):
...
@@ -129,13 +143,16 @@ class TorchAllocator(ExternalAllocator):
return
ten_tv
return
ten_tv
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
def
empty
(
self
,
name
:
str
,
shape
:
List
[
int
],
dtype
:
int
,
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
)
->
tv
.
Tensor
:
device
:
int
,
stream
:
int
=
0
,
is_temp_memory
:
bool
=
False
,
scale
:
float
=
1.0
)
->
tv
.
Tensor
:
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
if
self
.
is_quantized
:
ten
=
torch
.
_empty_affine_quantized
(
shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
th_dtype
,
device
=
dev
)
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
...
@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
...
@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
if
self
.
is_quantized
:
assert
th_dtype
not
in
_TH_QTYPES
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
...
@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
...
@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
if
dtype
in
_TORCH_UINT_WORKAROUNDS
and
value
<
0
:
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
raise
NotImplementedError
(
"you can't use full for unsigned dtypes"
)
dtype_bkp
=
dtype
dtype_bkp
=
dtype
th_dtype
=
_TV_DTYPE_TO_TORCH
[
dtype
]
th_dtype
=
self
.
_tv_dtype_to_torch
[
dtype
]
if
device
==
-
1
:
if
device
==
-
1
:
dev
=
self
.
cpudevice
dev
=
self
.
cpudevice
else
:
else
:
dev
=
self
.
gpudevice
dev
=
self
.
gpudevice
if
self
.
is_quantized
:
assert
th_dtype
not
in
_TH_QTYPES
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten
=
torch
.
full
(
shape
,
value
,
dtype
=
th_dtype
,
device
=
dev
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
...
...
spconv/pytorch/quantization/__init__.py
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.backend_cfg
import
(
get_spconv_backend_config
,
get_spconv_prepare_custom_config
,
get_spconv_convert_custom_config
)
from
.fake_q
import
(
get_default_spconv_trt_ptq_qconfig
,
get_default_spconv_trt_qat_qconfig
)
from
.qmapping
import
(
get_spconv_fmod_to_qat_mapping
,
get_spconv_qat_to_static_mapping
)
spconv/pytorch/quantization/backend_cfg.py
0 → 100644
View file @
b1c57a31
from
collections
import
namedtuple
from
typing
import
List
,
Dict
,
Union
,
Type
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.ao.quantization.backend_config
import
(
BackendConfig
,
BackendPatternConfig
,
DTypeConfig
,
ObservationType
,
get_tensorrt_backend_config
)
from
torch.ao.quantization.fx.custom_config
import
(
ConvertCustomConfig
,
FuseCustomConfig
,
PrepareCustomConfig
)
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
import
spconv.pytorch.conv
as
sconvmod
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
import
spconv.pytorch.quantization.quantized
as
snnq
import
spconv.pytorch.quantization.quantized.reference
as
snnqr
from
spconv.pytorch.constants
import
PYTORCH_VERSION
from
spconv.pytorch.quantization.fuse_mapping
import
(
fuse_conv_bn
,
fuse_conv_bn_relu
)
from
spconv.pytorch
import
ToDense
_SpConvMetadataDef
=
namedtuple
(
"_ConvMetadata"
,
[
"root"
,
"bn"
,
"reference"
,
"fused_conv_relu"
,
"fused_conv_bn"
,
"fused_conv_bn_relu"
,
"qat"
,
"relu_qat"
,
"bn_qat"
,
"bn_relu_qat"
])
_SpConvMetadatas
:
List
[
_SpConvMetadataDef
]
=
[]
for
t
in
sconvmod
.
DEFAULT_SPARSE_CONV_TYPES
:
_SpConvMetadatas
.
append
(
_SpConvMetadataDef
(
t
,
nn
.
BatchNorm1d
,
snnqr
.
SpConv
,
snni
.
SpconvReLUNd
,
snni
.
SpconvBnNd
,
snni
.
SpconvBnReLUNd
,
snniqat
.
SparseConv
,
snniqat
.
SparseConvReLU
,
snniqat
.
SparseConvBn
,
snniqat
.
SparseConvBnReLU
))
_SpConvMetadatas
.
append
(
_SpConvMetadataDef
(
sconvmod
.
SparseConvolution
,
nn
.
BatchNorm1d
,
snnqr
.
SpConv
,
snni
.
SpconvReLUNd
,
snni
.
SpconvBnNd
,
snni
.
SpconvBnReLUNd
,
snniqat
.
SparseConv
,
snniqat
.
SparseConvReLU
,
snniqat
.
SparseConvBn
,
snniqat
.
SparseConvBnReLU
))
def
_sequential_wrapper2
(
sequential
):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def
fuser_method
(
is_qat
,
m1
,
m2
):
return
sequential
(
m1
,
m2
)
return
fuser_method
# new cfg remove reverse pattern.
def
_get_spconv_configs
(
dtype_configs
):
"""
Return all configs related to conv modules and ops.
"""
conv_configs
=
[]
observation_type
=
ObservationType
.
OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
if
PYTORCH_VERSION
<=
[
1
,
13
,
1
]:
from
torch.ao.quantization.fuser_method_mappings
import
(
reverse2
,
reverse3
,
reverse_sequential_wrapper2
)
for
convs
in
_SpConvMetadatas
:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
torch
.
nn
.
ReLU
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# 2.3 functional conv + relu configs
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
bn
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse2
(
fuse_conv_bn
))
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
nn
.
ReLU
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (4) conv transpose and its fusion
# 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.bn, convs.transpose))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(reverse2(fuse_conv_bn))
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return
conv_configs
else
:
for
convs
in
_SpConvMetadatas
:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
torch
.
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn
)
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# # (4) conv transpose and its fusion
# # 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.transpose, convs.bn))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_conv_bn)
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return
conv_configs
weighted_op_qint8_dtype_config
=
DTypeConfig
(
input_dtype
=
torch
.
qint8
,
output_dtype
=
torch
.
qint8
,
weight_dtype
=
torch
.
qint8
,
bias_dtype
=
torch
.
float
,
)
conv_dtype_configs
=
[
weighted_op_qint8_dtype_config
,
]
_to_dense_cfg
=
(
BackendPatternConfig
(
ToDense
)
.
set_observation_type
(
ObservationType
.
OUTPUT_SHARE_OBSERVER_WITH_INPUT
))
backend_config
=
get_tensorrt_backend_config
()
\
.
set_backend_pattern_configs
(
_get_spconv_configs
(
conv_dtype_configs
)
+
[
_to_dense_cfg
])
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Tuple
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]]
=
{
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
}
def
get_spconv_backend_config
():
return
backend_config
def
get_spconv_prepare_custom_config
():
cfg
=
PrepareCustomConfig
()
cfg
.
non_traceable_module_classes
=
[
*
sconvmod
.
DEFAULT_SPARSE_CONV_TYPES
]
return
cfg
def
get_spconv_convert_custom_config
():
cfg
=
ConvertCustomConfig
()
cfg
.
set_observed_to_quantized_mapping
(
snni
.
SpconvReLUNd
,
snniq
.
SparseConvReLU
)
# cfg.set_observed_to_quantized_mapping(snni., snniq.SparseConvReLU)
return
cfg
\ No newline at end of file
spconv/pytorch/quantization/fake_q.py
View file @
b1c57a31
from
torch.ao.quantization.fake_quantize
import
FusedMovingAvgObsFakeQuantize
,
fused_wt_fake_quant_range_neg_127_to_127
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
import
torch
from
torch.ao.quantization.qconfig
import
QConfig
from
torch.ao.quantization.fake_quantize
import
(
from
torch.ao.quantization.observer
import
MovingAverageMinMaxObserver
FixedQParamsFakeQuantize
,
FusedMovingAvgObsFakeQuantize
,
FakeQuantize
,
default_fused_per_channel_wt_fake_quant
,
default_weight_fake_quant
,
default_per_channel_weight_fake_quant
)
from
torch.ao.quantization.observer
import
(
HistogramObserver
,
MovingAverageMinMaxObserver
,
default_weight_observer
,
default_placeholder_observer
,
default_per_channel_weight_observer
)
from
torch.ao.quantization.qconfig
import
QConfig
,
QConfigAny
,
default_reuse_input_qconfig
from
torch.ao.quantization.qconfig_mapping
import
QConfigMapping
,
_FIXED_QPARAMS_OP_TO_OBSERVER
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
,
Union
,
List
from
torch.ao.quantization
import
get_default_qconfig
from
spconv.pytorch.core
import
SparseConvTensor
__all__
=
[
"get_default_spconv_trt_ptq_qconfig"
,
"get_default_spconv_trt_qat_qconfig"
]
class
SparseFusedMovingAvgObsFakeQuantize
(
FusedMovingAvgObsFakeQuantize
):
class
SparseFusedMovingAvgObsFakeQuantize
(
FusedMovingAvgObsFakeQuantize
):
def
forward
(
self
,
input
:
SparseConvTensor
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
input
,
SparseConvTensor
):
# add lines to support spconv
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor):
# # add lines to support spconv
# x = input.features
# res_features = super().forward(x)
# return input.replace_feature(res_features)
# else:
# return super().forward(input)
class
SparseHistogramObserver
(
HistogramObserver
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
input
,
SparseConvTensor
):
# add lines to support spconv
# add lines to support spconv
x
=
input
.
features
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
default_symmetric_spconv_ptq_qconfig
=
QConfig
(
activation
=
SparseHistogramObserver
.
with_args
(
quant_min
=-
128
,
quant_max
=
127
,
dtype
=
torch
.
qint8
,
reduce_range
=
False
,
qscheme
=
torch
.
per_tensor_symmetric
,
eps
=
2
**
-
12
),
weight
=
default_per_channel_weight_observer
)
# default_symmetric_ptq_qconfig = QConfig(
# activation=HistogramObserver.with_args(quant_min=-128,
# quant_max=127,
# dtype=torch.qint8,
# reduce_range=False,
# eps=2 ** -12),
# weight=default_per_channel_weight_observer)
default_symmetric_spconv_qat_qconfig
=
QConfig
(
default_symmetric_spconv_qat_qconfig
=
QConfig
(
activation
=
SparseFusedMovingAvgObsFakeQuantize
.
with_args
(
observer
=
MovingAverageMinMaxObserver
,
activation
=
SparseFusedMovingAvgObsFakeQuantize
.
with_args
(
observer
=
MovingAverageMinMaxObserver
,
...
@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
...
@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
quant_max
=
127
,
quant_max
=
127
,
dtype
=
torch
.
qint8
,
dtype
=
torch
.
qint8
,
reduce_range
=
False
,
reduce_range
=
False
,
qscheme
=
torch
.
per_tensor_symmetric
,
eps
=
2
**
-
12
),
eps
=
2
**
-
12
),
weight
=
fused_wt_fake_quant_range_neg_127_to_127
)
weight
=
default_fused_per_channel_wt_fake_quant
)
def
get_default_spconv_trt_ptq_qconfig
(
backend
,
version
):
return
default_symmetric_spconv_ptq_qconfig
def
get_default_spconv_trt_qat_qconfig
(
backend
,
version
):
return
default_symmetric_spconv_qat_qconfig
def
get_default_spconv_qconfig_mapping
(
is_qat
:
bool
,
backend
:
str
=
"x86"
,
version
:
int
=
0
)
->
QConfigMapping
:
"""
From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend.
"""
# get_default_qconfig(backend, version)
if
is_qat
:
qconfig
=
get_default_spconv_trt_qat_qconfig
(
backend
,
version
)
else
:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
# weight=default_per_channel_weight_observer)
qconfig
=
get_default_spconv_trt_ptq_qconfig
(
backend
,
version
)
default_weight
=
default_weight_fake_quant
if
is_qat
else
default_weight_observer
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
# so we have to modify the weight observer to default_weight_observer or another
# per tensor supported observer.
# see https://github.com/pytorch/pytorch/issues/47535
if
backend
in
(
"fbgemm"
,
"x86"
):
qconfig_transpose
=
QConfig
(
activation
=
qconfig
.
activation
,
weight
=
default_weight
)
else
:
qconfig_transpose
=
qconfig
# currently layernorm only supports float weights
# we have to add this because otherwise there will be a extra quantize-dequantize pair
qconfig_layernorm
=
QConfig
(
activation
=
qconfig
.
activation
,
weight
=
default_placeholder_observer
)
qconfig_mapping
=
QConfigMapping
()
\
.
set_global
(
qconfig
)
\
.
set_object_type
(
"reshape"
,
default_reuse_input_qconfig
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose1d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose2d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
ConvTranspose3d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose1d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose2d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
conv_transpose3d
,
qconfig_transpose
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
layer_norm
,
qconfig_layernorm
)
\
.
set_object_type
(
torch
.
nn
.
LayerNorm
,
qconfig_layernorm
)
\
# Use special observers for ops with fixed qparams
fixed_qparams_observer_to_qconfig
:
Dict
[
Any
,
QConfigAny
]
=
{}
for
fixed_qparams_op
,
observer
in
_FIXED_QPARAMS_OP_TO_OBSERVER
.
items
():
if
observer
in
fixed_qparams_observer_to_qconfig
:
fixed_qparams_qconfig
=
fixed_qparams_observer_to_qconfig
[
observer
]
else
:
if
is_qat
:
activation
=
FixedQParamsFakeQuantize
.
with_args
(
observer
=
observer
)
else
:
activation
=
observer
fixed_qparams_qconfig
=
QConfig
(
activation
=
activation
,
weight
=
default_weight
)
fixed_qparams_observer_to_qconfig
[
observer
]
=
fixed_qparams_qconfig
qconfig_mapping
.
set_object_type
(
fixed_qparams_op
,
fixed_qparams_qconfig
)
# QConfig for fused ops for onednn backend
# Separate ops are required to have the same qconfig as fused ops
# TODO: we should be able to configure qconfig for patterns
if
backend
==
'onednn'
:
qconfig_mapping
.
set_object_type
(
torch
.
nn
.
Linear
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
LeakyReLU
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
leaky_relu
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
Tanh
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
tanh
,
qconfig
)
return
qconfig_mapping
spconv/pytorch/quantization/fuse_mapping.py
View file @
b1c57a31
...
@@ -3,9 +3,9 @@ import torch.nn as nn
...
@@ -3,9 +3,9 @@ import torch.nn as nn
import
spconv.pytorch
as
spconv
import
spconv.pytorch
as
spconv
from
.utils
import
fuse_spconv_bn_eval
from
.utils
import
fuse_spconv_bn_eval
from
.
import
intrinsic
as
snni
from
.
import
intrinsic
as
snni
from
.
conv_fused
import
SparseConvBn
,
SparseConvBnReLU
from
.
intrinsic.qat.modules
import
SparseConvBn
,
SparseConvBnReLU
,
SparseConvBnAddReLU
from
spconv.pytorch.conv
import
DEFAULT_SPARSE_CONV_TYPES
def
fuse_conv_bn
(
conv
,
bn
):
def
fuse_conv_bn
(
is_qat
,
conv
,
bn
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
Args:
...
@@ -22,18 +22,10 @@ def fuse_conv_bn(conv, bn):
...
@@ -22,18 +22,10 @@ def fuse_conv_bn(conv, bn):
"Conv and BN both must be in the same mode (train or eval)."
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map
=
{
fused_module_class_map
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnNd
,
k
:
snni
.
SpconvBnNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnNd
,
}
}
if
conv
.
training
:
if
is_qat
:
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm2d with tracking_running_stats set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm2d with tracking_running_stats set to True'
...
@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
...
@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
else
:
else
:
return
fuse_spconv_bn_eval
(
conv
,
bn
)
return
fuse_spconv_bn_eval
(
conv
,
bn
)
def
fuse_conv_bn_relu
(
conv
,
bn
,
relu
):
def
fuse_conv_bn_relu
(
is_qat
,
conv
,
bn
,
relu
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
Args:
...
@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
assert
(
conv
.
training
==
bn
.
training
==
relu
.
training
),
\
assert
(
conv
.
training
==
bn
.
training
==
relu
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
"Conv and BN both must be in the same mode (train or eval)."
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
if
conv
.
training
:
if
is_qat
:
map_to_fused_module_train
=
{
map_to_fused_module_train
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnReLUNd
,
k
:
snni
.
SpconvBnReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnReLUNd
,
}
}
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
...
@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
else
:
else
:
map_to_fused_module_eval
=
{
map_to_fused_module_eval
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvReLUNd
,
k
:
snni
.
SpconvReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
spconv
.
SparseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvReLUNd
,
}
}
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
if
fused_module
is
not
None
:
if
fused_module
is
not
None
:
...
@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
...
@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
else
:
else
:
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
:
Dict
[
Tuple
,
Union
[
nn
.
Sequential
,
Callable
]]
=
{
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
}
# }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
# Default map for swapping float module to qat modules
# Default map for swapping float module to qat modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS
:
Dict
[
Callable
,
Any
]
=
{
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni
.
SpconvBnNd
:
SparseConvBn
,
snni
.
SpconvBnReLUNd
:
SparseConvBnReLU
,
}
spconv/pytorch/quantization/intrinsic/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SpconvBnNd
,
SpconvBnReLUNd
,
SpconvBnAddReLUNd
,
SpconvReLUNd
spconv/pytorch/quantization/intrinsic.py
→
spconv/pytorch/quantization/intrinsic
/modules
.py
View file @
b1c57a31
...
@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
...
@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
import
torch.ao.nn.intrinsic
as
nni
import
torch.ao.nn.intrinsic
as
nni
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.modules
import
is_spconv_module
from
spconv.pytorch.core
import
SparseConvTensor
class
_FusedSparseModule
(
nni
.
_FusedModule
):
def
forward
(
self
,
input
):
for
k
,
module
in
self
.
_modules
.
items
():
if
is_spconv_module
(
module
):
# use SpConvTensor as input
if
isinstance
(
input
,
list
):
input
=
module
(
input
)
else
:
# assert isinstance(input, spconv.SparseConvTensor)
# self._sparity_dict[k] = input.sparity
input
=
module
(
input
)
else
:
if
isinstance
(
input
,
SparseConvTensor
):
if
input
.
indices
.
shape
[
0
]
!=
0
:
input
=
input
.
replace_feature
(
module
(
input
.
features
))
else
:
input
=
module
(
input
)
return
input
class
SpconvReLUNd
(
nni
.
_FusedModule
):
class
SpconvReLUNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv3d and ReLU modules.
r
"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
relu
):
def
__init__
(
self
,
conv
,
relu
):
...
@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
...
@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
type
(
conv
),
type
(
relu
))
type
(
conv
),
type
(
relu
))
super
().
__init__
(
conv
,
relu
)
super
().
__init__
(
conv
,
relu
)
class
SpconvBnNd
(
nni
.
_FusedModule
):
class
SpconvBnNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
r
"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
):
def
__init__
(
self
,
conv
,
bn
):
...
@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
...
@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
type
(
conv
),
type
(
bn
))
type
(
conv
),
type
(
bn
))
super
().
__init__
(
conv
,
bn
)
super
().
__init__
(
conv
,
bn
)
class
SpconvBnReLUNd
(
_FusedSparseModule
):
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
,
relu
):
assert
isinstance
(
conv
,
SparseConvolution
)
and
isinstance
(
bn
,
BatchNorm1d
)
and
\
isinstance
(
relu
,
ReLU
),
'Incorrect types for input modules{}{}{}'
\
.
format
(
type
(
conv
),
type
(
bn
),
type
(
relu
))
super
().
__init__
(
conv
,
bn
,
relu
)
class
SpconvBnReLUNd
(
nni
.
_FusedModule
):
class
SpconvBn
Add
ReLUNd
(
_Fused
Sparse
Module
):
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
,
relu
):
def
__init__
(
self
,
conv
,
bn
,
relu
):
...
...
spconv/pytorch/quantization/intrinsic/qat/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SparseConvBn
,
SparseConvBnAddReLU
,
SparseConvBnReLU
,
SparseConv
,
SparseConvReLU
\ No newline at end of file
spconv/pytorch/quantization/
conv_fused
.py
→
spconv/pytorch/quantization/
intrinsic/qat/modules
.py
View file @
b1c57a31
...
@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
...
@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
import
torch.ao.nn.qat
as
nnqat
import
torch.ao.nn.qat
as
nnqat
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn
import
init
from
torch.nn
import
init
from
torch.nn.utils
import
fuse_conv_bn_weights
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
typing
import
TypeVar
from
typing
import
TypeVar
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.conv
import
SparseConvolution
...
@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
...
@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic
as
snni
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
class
_SparseConv
(
SparseConvolution
,
nni
.
_FusedModule
):
_FLOAT_MODULE
=
MOD
_FLOAT_CONV_MODULE
=
SparseConvolution
def
__init__
(
self
,
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
,
qconfig
=
None
,
device
=
None
,
dtype
=
None
)
->
None
:
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
SparseConvolution
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
False
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
name
=
name
,
**
factory_kwargs
)
assert
qconfig
,
'qconfig must be provided for QAT module'
self
.
qconfig
=
qconfig
self
.
weight_fake_quant
=
qconfig
.
weight
(
factory_kwargs
=
factory_kwargs
)
def
forward
(
self
,
input
):
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight_fake_quant
(
self
.
weight
),
self
.
bias
)
@
staticmethod
def
from_float
(
cls
,
mod
):
r
"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert
type
(
mod
)
==
cls
.
_FLOAT_MODULE
,
(
"qat."
+
cls
.
__name__
+
".from_float only works for "
+
cls
.
_FLOAT_MODULE
.
__name__
# type: ignore[attr-defined]
)
assert
hasattr
(
mod
,
'qconfig'
),
'Input float module must have qconfig defined'
assert
mod
.
qconfig
,
'Input float module must have a valid qconfig'
if
issubclass
(
type
(
mod
),
nni
.
_FusedModule
):
mod
=
mod
[
0
]
# type: ignore[index]
conv
:
SparseConvolution
=
mod
qconfig
=
mod
.
qconfig
qat_conv
=
cls
(
conv
.
ndim
,
conv
.
in_channels
,
conv
.
out_channels
,
conv
.
kernel_size
,
conv
.
stride
,
conv
.
padding
,
conv
.
dilation
,
conv
.
groups
,
conv
.
bias
is
not
None
,
subm
=
conv
.
subm
,
output_padding
=
conv
.
output_padding
,
transposed
=
conv
.
transposed
,
inverse
=
conv
.
inverse
,
indice_key
=
conv
.
indice_key
,
algo
=
conv
.
algo
,
fp32_accum
=
conv
.
fp32_accum
,
record_voxel_count
=
conv
.
record_voxel_count
,
act_type
=
conv
.
act_type
,
act_alpha
=
conv
.
act_alpha
,
act_beta
=
conv
.
act_beta
,
name
=
conv
.
name
,
qconfig
=
qconfig
)
qat_conv
.
weight
=
mod
.
weight
qat_conv
.
bias
=
mod
.
bias
return
qat_conv
def
to_float
(
self
):
""" This works for both single qat conv, and the qat conv - relu modules
to convert the qat module to a floating point module
"""
cls
=
type
(
self
)
conv
=
cls
.
_FLOAT_CONV_MODULE
(
# type: ignore[attr-defined]
self
.
ndim
,
self
.
in_channels
,
self
.
out_channels
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
,
self
.
bias
is
not
None
,
subm
=
self
.
subm
,
output_padding
=
self
.
output_padding
,
transposed
=
self
.
transposed
,
inverse
=
self
.
inverse
,
indice_key
=
self
.
indice_key
,
algo
=
self
.
algo
,
fp32_accum
=
self
.
fp32_accum
,
record_voxel_count
=
self
.
record_voxel_count
,
act_type
=
self
.
act_type
,
act_alpha
=
self
.
act_alpha
,
act_beta
=
self
.
act_beta
,
name
=
self
.
name
)
conv
.
weight
=
torch
.
nn
.
Parameter
(
self
.
weight
.
detach
())
if
self
.
bias
is
not
None
:
conv
.
bias
=
torch
.
nn
.
Parameter
(
self
.
bias
.
detach
())
# conv relu
if
issubclass
(
cls
,
nni
.
_FusedModule
):
modules
=
[
conv
]
assert
hasattr
(
cls
,
"_FLOAT_RELU_MODULE"
)
relu
=
cls
.
_FLOAT_RELU_MODULE
()
# type: ignore[attr-defined]
modules
.
append
(
relu
)
fused
=
cls
.
_FLOAT_MODULE
(
*
modules
)
# type: ignore[arg-type, attr-defined, operator]
fused
.
train
(
self
.
training
)
return
fused
else
:
return
conv
class
SparseConv
(
_SparseConv
,
SparseConvolution
):
r
"""
A Conv1d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv1d`
Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE
=
SparseConvolution
_FLOAT_CONV_MODULE
=
SparseConvolution
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
().
from_float
(
cls
,
mod
)
class
SparseConvReLU
(
SparseConv
,
nni
.
_FusedModule
):
r
"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for weight for
quantization aware training.
We combined the interface of :class:`~torch.nn.Conv2d` and
:class:`~torch.nn.BatchNorm2d`.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE
=
snni
.
SpconvReLUNd
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
None
_FLOAT_RELU_MODULE
=
nn
.
ReLU
def
forward
(
self
,
input
):
x
=
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight_fake_quant
(
self
.
weight
),
self
.
bias
)
return
x
.
replace_feature
(
F
.
relu
(
x
.
features
))
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvReLU
,
cls
).
from_float
(
mod
)
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
_version
=
2
_version
=
2
...
@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
...
@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias
=
torch
.
zeros_like
(
self
.
bias
,
dtype
=
input
.
features
.
dtype
)
zero_bias
=
torch
.
zeros_like
(
self
.
bias
,
dtype
=
input
.
features
.
dtype
)
else
:
else
:
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
conv_spt
=
self
.
_conv_forward
(
input
,
scaled_weight
,
zero_bias
)
conv_spt
=
self
.
_conv_forward
(
self
.
training
,
input
,
scaled_weight
,
zero_bias
)
conv
=
conv_spt
.
features
conv
=
conv_spt
.
features
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
...
@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
...
@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
if
cls
.
_FLOAT_BN_MODULE
:
# type: ignore[attr-defined]
if
cls
.
_FLOAT_BN_MODULE
:
# type: ignore[attr-defined]
# fuse bn into conv
# fuse bn into conv
conv
.
weight
,
conv
.
bias
=
fuse_conv_bn_weights
(
conv
.
weight
,
conv
.
bias
=
fuse_
sp
conv_bn_weights
(
conv
.
weight
,
conv
.
weight
,
conv
.
bias
,
conv
.
bias
,
self
.
bn
.
running_mean
,
self
.
bn
.
running_mean
,
...
@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
...
@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
@
classmethod
@
classmethod
def
from_float
(
cls
,
mod
):
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvBnReLU
,
cls
).
from_float
(
mod
)
return
super
(
SparseConvBnReLU
,
cls
).
from_float
(
mod
)
class
SparseConvBnAddReLU
(
_SparseConvBn
):
r
"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE
=
snni
.
SpconvBnReLUNd
# type: ignore[assignment]
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
nn
.
BatchNorm1d
_FLOAT_RELU_MODULE
=
nn
.
ReLU
# type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE
=
snni
.
SpconvReLUNd
def
forward
(
self
,
input
,
add_input
):
x
=
_SparseConvBn
.
_forward
(
self
,
input
,
add_input
)
return
x
.
replace_feature
(
F
.
relu
(
x
.
features
))
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvBnAddReLU
,
cls
).
from_float
(
mod
)
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.conv_relu
import
*
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment