Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
e387ee74
Commit
e387ee74
authored
Jan 04, 2023
by
yan.yan
Browse files
sync quantization code
parent
b1c57a31
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1280 additions
and
1126 deletions
+1280
-1126
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+289
-47
spconv/algo.py
spconv/algo.py
+6
-0
spconv/build.py
spconv/build.py
+2
-2
spconv/constants.py
spconv/constants.py
+2
-0
spconv/core.py
spconv/core.py
+347
-357
spconv/pytorch/__init__.py
spconv/pytorch/__init__.py
+2
-1
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+16
-482
spconv/pytorch/core.py
spconv/pytorch/core.py
+16
-0
spconv/pytorch/modules.py
spconv/pytorch/modules.py
+35
-1
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+22
-9
spconv/pytorch/pool.py
spconv/pytorch/pool.py
+15
-3
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+1
-0
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+412
-181
spconv/pytorch/quantization/core.py
spconv/pytorch/quantization/core.py
+21
-0
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+4
-2
spconv/pytorch/quantization/fuse_mapping.py
spconv/pytorch/quantization/fuse_mapping.py
+26
-30
spconv/pytorch/quantization/intrinsic/__init__.py
spconv/pytorch/quantization/intrinsic/__init__.py
+1
-1
spconv/pytorch/quantization/intrinsic/modules.py
spconv/pytorch/quantization/intrinsic/modules.py
+24
-0
spconv/pytorch/quantization/intrinsic/qat/__init__.py
spconv/pytorch/quantization/intrinsic/qat/__init__.py
+2
-1
spconv/pytorch/quantization/intrinsic/qat/modules.py
spconv/pytorch/quantization/intrinsic/qat/modules.py
+37
-9
No files found.
example/mnist/mnist_qat.py
View file @
e387ee74
...
...
@@ -13,25 +13,37 @@
# limitations under the License.
from
__future__
import
print_function
import
argparse
import
contextlib
import
copy
from
typing
import
Dict
,
Optional
import
torch
import
spconv.pytorch
as
spconv
import
torch.ao.quantization
import
torch.ao.quantization.quantize_fx
as
qfx
import
torch.cuda.amp
import
torch.fx
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
from
torchvision
import
datasets
,
transforms
from
torch.ao.quantization
import
(
DeQuantStub
,
QuantStub
,
get_default_qconfig_mapping
)
from
torch.ao.quantization.fx._lower_to_native_backend
import
\
STATIC_LOWER_FUSED_MODULE_MAP
,
STATIC_LOWER_MODULE_MAP
from
torch.optim.lr_scheduler
import
StepLR
import
contextlib
import
torch.cuda.amp
import
torch.ao.quantization
from
torch.ao.quantization
import
QuantStub
,
DeQuantStub
import
torch.ao.quantization.quantize_fx
as
qfx
from
spconv.pytorch.quantization.fake_q
import
get_default_spconv_qconfig_mapping
from
torchvision
import
datasets
,
transforms
import
spconv.pytorch
as
spconv
import
spconv.pytorch.quantization
as
spconvq
from
spconv.pytorch.quantization
import
get_default_spconv_trt_ptq_qconfig
from
torch.ao.quantization
import
get_default_qconfig_mapping
from
spconv.pytorch.quantization.backend_cfg
import
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from
torch.ao.quantization.fx._lower_to_native_backend
import
STATIC_LOWER_FUSED_MODULE_MAP
from
spconv.pytorch.quantization.backend_cfg
import
\
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
,
SPCONV_STATIC_LOWER_MODULE_MAP
from
spconv.pytorch.quantization.core
import
quantize_per_tensor
from
spconv.pytorch.quantization.fake_q
import
\
get_default_spconv_qconfig_mapping
from
spconv.pytorch.quantization.intrinsic.modules
import
SpconvBnAddReLUNd
,
SpconvAddReLUNd
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
@
contextlib
.
contextmanager
def
identity_ctx
():
...
...
@@ -57,6 +69,142 @@ class SparseConvBNReLU(spconv.SparseSequential):
nn
.
ReLU
(
inplace
=
False
)
)
class SparseBasicBlock(spconv.SparseModule):
    """Sparse residual block assembled from two ``SparseSequential`` stages.

    The first stage is ``conv -> bn -> relu`` and the second is ``conv -> bn``.
    The shortcut is added to the feature matrix of the second stage and a final
    ReLU is applied. Batch norm and ReLU are plain ``torch.nn`` modules wrapped
    in ``spconv.SparseSequential``.
    """
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        """Create both conv stages, the output activation and the shortcut hooks.

        Args:
            in_planes: channel count given to the first ``SubMConv2d``.
            out_planes: channel count produced by both convolutions.
            stride: forwarded positionally to both convolutions.
            downsample: optional callable applied to the block input to produce
                the shortcut; ``None`` keeps the input features as shortcut.
        """
        super().__init__()
        # Layers are created in the same order as before so weight
        # initialisation consumes random numbers in the same sequence.
        first_conv = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        second_conv = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                        bias=False)
        first_bn = nn.BatchNorm1d(out_planes, momentum=0.1)
        second_bn = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(
            conv=first_conv,
            bn=first_bn,
            relu=nn.ReLU(inplace=True),
        )
        self.conv2_bn = spconv.SparseSequential(conv=second_conv, bn=second_bn)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # NOTE(review): presumably gives FX pattern matching an explicit node
        # for the shortcut branch -- confirm against the spconv backend config.
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        """Run both conv stages and add the shortcut before the last ReLU."""
        shortcut = self.iden_for_fx_match(x.features)
        y = self.conv1_bn_relu(x)
        y = self.conv2_bn(y)
        if self.downsample is not None:
            # A configured downsample replaces the identity shortcut.
            shortcut = self.downsample(x)
        activated = self.relu(y.features + shortcut)
        return y.replace_feature(activated)
class SparseBasicBlock1(spconv.SparseModule):
    """Sparse residual block with conv, batch norm and ReLU as separate members.

    Unlike the ``SparseSequential`` based variant, normalisation and activation
    are applied by hand to ``.features`` and written back with
    ``replace_feature``.
    """
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        """Register the individual layers of the block.

        Args:
            in_planes: channel count given to the first ``SubMConv2d``.
            out_planes: channel count produced by both convolutions.
            stride: forwarded positionally to both convolutions.
            downsample: stored on the module; ``forward`` does not apply it.
        """
        super().__init__()
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.downsample = downsample
        # NOTE(review): presumably gives FX pattern matching an explicit node
        # for the shortcut branch -- confirm against the spconv backend config.
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        """Apply conv-bn-relu, conv-bn, then add the input features and ReLU."""
        shortcut = self.iden_for_fx_match(x.features)

        y = self.conv1(x)
        first_stage = self.relu1(self.norm1(y.features))
        y = y.replace_feature(first_stage)

        y = self.conv2(y)
        second_stage = self.norm2(y.features)
        y = y.replace_feature(second_stage)

        # The downsample branch is intentionally not applied here; the
        # shortcut is always the identity of the input features.
        merged = self.relu2(y.features + shortcut)
        return y.replace_feature(merged)
class SparseBasicBlock2(spconv.SparseModule):
    """Sparse residual block built only from spconv's sparse wrapper layers.

    ``SparseBatchNorm``, ``SparseReLU`` and ``SparseIdentity`` are called on
    the tensor object itself, so ``forward`` never touches ``.features``
    directly and the shortcut is added with ``+`` on those objects.
    """
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        """Register the layers of the block.

        Args:
            in_planes: channel count given to the first ``SubMConv2d``.
            out_planes: channel count produced by both convolutions.
            stride: forwarded positionally to both convolutions.
            downsample: optional callable applied to the block input to produce
                the shortcut; ``None`` keeps the identity shortcut.
        """
        super().__init__()
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.relu2 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        # NOTE(review): presumably gives FX pattern matching an explicit node
        # for the shortcut branch -- confirm against the spconv backend config.
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        """Run conv-bn-relu, conv-bn, add the shortcut and apply the last ReLU."""
        shortcut = self.iden_for_fx_match(x)

        y = self.conv1(x)
        y = self.norm1(y)
        y = self.relu1(y)

        y = self.conv2(y)
        y = self.norm2(y)

        if self.downsample is not None:
            # A configured downsample replaces the identity shortcut.
            shortcut = self.downsample(x)

        return self.relu2(y + shortcut)
class SparseBasicBlock3(spconv.SparseModule):
    """Sparse residual block whose second stage is a fused conv-add-relu module.

    The first stage is ``conv -> SparseBatchNorm -> SparseReLU``. The second
    convolution is wrapped together with a ``SparseReLU`` in
    ``SpconvAddReLUNd``, which is called with both the first-stage output and
    the shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        """Register the layers of the block.

        Args:
            in_planes: channel count given to the first ``SubMConv2d``.
            out_planes: channel count produced by both convolutions.
            stride: forwarded positionally to both convolutions.
            downsample: optional callable applied to the block input to produce
                the shortcut; ``None`` keeps the identity shortcut.
        """
        spconv.SparseModule.__init__(self)
        self.conv1 = spconv.SubMConv2d(in_planes,
                                       out_planes,
                                       3,
                                       stride,
                                       1,
                                       bias=False)
        conv2 = spconv.SubMConv2d(out_planes,
                                  out_planes,
                                  3,
                                  stride,
                                  1,
                                  bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        # NOTE(review): an unused second SparseBatchNorm used to be constructed
        # here and was never registered or called, so it has been removed.
        # The second convolution therefore has no batch norm; confirm whether
        # SpconvBnAddReLUNd was the intended fused module.
        self.residual_conv = SpconvAddReLUNd(conv2,
                                             spconv.SparseReLU(inplace=True))
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        # NOTE(review): presumably gives FX pattern matching an explicit node
        # for the shortcut branch -- confirm against the spconv backend config.
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        """Run the first stage, then the fused conv-add-relu with the shortcut."""
        identity = self.iden_for_fx_match(x)
        out = self.conv1(x)
        out = self.relu1(self.norm1(out))
        if self.downsample is not None:
            # A configured downsample replaces the identity shortcut.
            identity = self.downsample(x)
        # The fused module receives the main branch and the shortcut together.
        out = self.residual_conv(out, identity)
        return out
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
...
...
@@ -126,7 +274,7 @@ class NetV2(nn.Module):
class
NetPTQ
(
nn
.
Module
):
"""pytorch currently don't support cuda int8 inference, so
we
only us
e sparse
ops
here.
we
build a pur
e sparse
network
here.
"""
def
__init__
(
self
):
super
(
NetPTQ
,
self
).
__init__
()
...
...
@@ -138,7 +286,6 @@ class NetPTQ(nn.Module):
SparseConvBNReLU
(
64
,
64
,
3
,
2
,
1
),
# 4x4
spconv
.
SparseConv2d
(
64
,
10
,
4
,
4
),
spconv
.
ToDense
(),
)
# self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10)
...
...
@@ -158,22 +305,47 @@ class NetPTQ(nn.Module):
# print(x_sp.shape)
x
=
x_sp
x
=
torch
.
flatten
(
x
,
1
)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
class ResidualNetPTQ(nn.Module):
    """Purely sparse MNIST classifier containing one residual block.

    PyTorch currently does not support CUDA int8 inference, so the whole
    network is built from sparse layers. Quant/dequant stubs bracket the
    sparse stack.
    """

    def __init__(self):
        """Assemble the sparse layer stack and the quantization stubs."""
        super().__init__()
        stages = [
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock2(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*stages)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor,
                batch_size: int):
        """Classify a batch given as sparse features and indices.

        Args:
            features: feature tensor passed through the quant stub and then
                used as the features of the ``SparseConvTensor``.
            indices: index tensor handed to ``SparseConvTensor`` unchanged.
            batch_size: batch size handed to ``SparseConvTensor``.

        Returns:
            ``log_softmax`` over dimension 1 of the flattened network output.
        """
        quantized = self.quant(features)
        # The SparseConvTensor is created manually with a fixed 28x28 spatial
        # shape (see SparseConvTensor.from_dense for the dense equivalent).
        sparse_in = spconv.SparseConvTensor(quantized, indices, [28, 28],
                                            batch_size)
        net_out = self.net(sparse_in)
        flat = torch.flatten(net_out, 1)
        logits = self.dequant(flat)
        return F.log_softmax(logits, dim=1)
class
NetDense
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -184,6 +356,8 @@ class NetDense(nn.Module):
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
self
.
fc1
=
nn
.
Linear
(
9216
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
iden
=
spconv
.
SparseIdentity
()
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
...
...
@@ -195,6 +369,7 @@ class NetDense(nn.Module):
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
iden
(
x
)
x
=
F
.
max_pool2d
(
x
,
2
)
x
=
self
.
dropout1
(
x
)
x
=
torch
.
flatten
(
x
,
1
)
...
...
@@ -299,6 +474,54 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else
:
output
=
model
(
image
)
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Swap ``torch.quantize_per_tensor`` calls for the spconv-aware version.

    ``torch.quantize_per_tensor`` does not support ``SparseConvTensor``, so
    every ``call_function`` node targeting it is retargeted to the custom
    ``quantize_per_tensor``. The module is modified in place and returned.
    """
    graph = m.graph
    for fx_node in graph.nodes:
        # For call_function nodes, the target attribute holds the callee.
        calls_builtin_quantize = (
            fx_node.op == 'call_function'
            and fx_node.target == torch.quantize_per_tensor
        )
        if calls_builtin_quantize:
            fx_node.target = quantize_per_tensor
    # Validate that the graph is still well-formed before regenerating code.
    graph.lint()
    m.recompile()
    return m
def is_dequantize_node(node):
    """Return True when ``node`` is an FX ``call_method`` node for ``dequantize``.

    Any object that is not a ``torch.fx.Node`` yields False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "call_method":
        return False
    return node.target == "dequantize"
def
_get_module
(
node
:
torch
.
fx
.
Node
,
modules
:
Dict
[
str
,
nn
.
Module
])
->
Optional
[
nn
.
Module
]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if
node
.
op
==
"call_module"
and
str
(
node
.
target
)
in
modules
:
return
modules
[
str
(
node
.
target
)]
else
:
return
None
def remove_conv_add_dq(
    model: torch.fx.graph_module.GraphModule
) -> torch.fx.graph_module.GraphModule:
    """Bypass the dequantize feeding the second input of conv-add-relu modules.

    For every ``call_module`` node whose module is exactly
    ``snniq.SparseConvAddReLU``, the second positional input is inspected. If
    it is a ``dequantize`` call, the node is rewired to consume that call's
    input directly. Dead nodes are then eliminated, the graph is linted and
    the module is recompiled.

    Args:
        model: graph module to rewrite; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        # _get_module returns None for anything but call_module nodes, so the
        # exact-type check also filters by op. Subclasses are not rewritten.
        if type(_get_module(n, modules)) is not snniq.SparseConvAddReLU:
            continue
        # A call without a second positional input has nothing to strip;
        # skip it instead of raising IndexError.
        if len(n.args) < 2:
            continue
        dq_node = n.args[1]
        if is_dequantize_node(dq_node):
            # Feed the fused module with the still-quantized tensor.
            n.replace_input_with(dq_node, dq_node.args[0])
    # Drop dequantize nodes that no longer have any users.
    model.graph.eliminate_dead_code()
    # Check the graph is well-formed before regenerating the forward code.
    model.graph.lint()
    model.recompile()
    return model
def
main
():
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
...
@@ -361,11 +584,11 @@ def main():
torch
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
model
=
NetPTQ
().
to
(
device
)
model
=
Residual
NetPTQ
().
to
(
device
)
else
:
model
=
NetDense
().
to
(
device
)
...
...
@@ -401,42 +624,61 @@ def main():
train
(
args
,
model
,
device
,
train_loader
,
optimizer
,
epoch
)
test
(
args
,
model
,
device
,
test_loader
)
scheduler
.
step
()
#
if args.save_model:
#
torch.save(model.state_dict(), "mnist_cnn.pt")
if
args
.
save_model
:
torch
.
save
(
model
.
state_dict
(),
"mnist_cnn.pt"
)
model
.
eval
()
STATIC_LOWER_FUSED_MODULE_MAP
.
update
(
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
)
if
not
args
.
sparse
:
model
=
model
.
cpu
()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
model_qat
=
copy
.
deepcopy
(
model
)
STATIC_LOWER_FUSED_MODULE_MAP
.
update
(
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
)
STATIC_LOWER_MODULE_MAP
.
update
(
SPCONV_STATIC_LOWER_MODULE_MAP
)
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
qconfig_mapping
=
get_default_spconv_qconfig_mapping
(
False
)
prepare_cfg
=
spconvq
.
get_spconv_prepare_custom_config
()
backend_cfg
=
spconvq
.
get_spconv_backend_config
()
convert_cfg
=
spconvq
.
get_spconv_convert_custom_config
()
#
convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model
=
qfx
.
prepare_fx
(
model
,
qconfig_mapping
,
(),
backend_config
=
backend_cfg
,
prepare_custom_config
=
prepare_cfg
)
# prepared_model.print_readable()
print
([
type
(
m
)
for
m
in
prepared_model
.
modules
()])
print
(
prepared_model
)
# print(prepared_model)
# breakpoint()
# print(prepared_model)
# calibrate: run model with some inputs
calibrate
(
args
,
prepared_model
,
test_loader
,
qdevice
)
#
calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model
=
qfx
.
convert_fx
(
prepared_model
,
qconfig_mapping
=
qconfig_mapping
,
backend_config
=
backend_cfg
)
converted_model
=
transform_qdq
(
converted_model
)
# test converted ptq model with int8 kernel
remove_conv_add_dq
(
converted_model
)
converted_model
=
qfx
.
convert_to_reference_fx
(
prepared_model
,
convert_cfg
,
qconfig_mapping
=
qconfig_mapping
,
backend_config
=
backend_cfg
)
print
([
type
(
m
)
for
m
in
converted_model
.
modules
()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print
(
converted_model
)
breakpoint
()
test
(
args
,
converted_model
,
qdevice
,
test_loader
)
# do qat
# qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
# prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# # converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# # breakpoint()
# print(prepared_model_qat)
# train(args, prepared_model_qat, qdevice, train_loader, optimizer, 1)
# converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# converted_model = transform_qdq(converted_model)
# test(args, converted_model, qdevice, test_loader)
# # [type(m) for m in prepared_model_qat.modules()]
# # model.qconfig = get_default_spconv_trt_ptq_qconfig()
# # prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# # convert_custom_config_dict = spconvq.get_convert_custom_config()
# # torch.ao.quantization.prepare(model, inplace=True)
# # print('Post Training Quantization Prepare: Inserting Observers')
# # print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# # test(args, model, device, test_loader)
# print(converted_model)
# you will see some nvrtc compile log here, which means int8 kernel is used.
breakpoint
()
# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
spconv/algo.py
View file @
e387ee74
...
...
@@ -188,10 +188,16 @@ class ConvTunerSimple(ConvTunerSimpleBase):
cudadevrt_p
=
get_cudadevrt_path
()
assert
cudadevrt_p
is
not
None
,
"DynamicParallism must have cudadevrt"
cudadevrt
=
str
(
cudadevrt_p
)
# mod = CummNVRTCModule([kernel],
# cudadevrt_path=cudadevrt,
# verbose=True,
# custom_names=custom_names,
# verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
...
...
spconv/build.py
View file @
e387ee74
...
...
@@ -18,10 +18,10 @@ from typing import List
import
pccm
from
pccm.utils
import
project_is_editable
,
project_is_installed
from
ccimport.compat
import
InWindows
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
from
.constants
import
PACKAGE_NAME
,
PACKAGE_ROOT
,
DISABLE_JIT
,
SPCONV_INT8_DEBUG
if
project_is_installed
(
PACKAGE_NAME
)
and
project_is_editable
(
PACKAGE_NAME
)
and
not
DISABLE_JIT
and
False
:
PACKAGE_NAME
)
and
not
DISABLE_JIT
and
not
SPCONV_INT8_DEBUG
:
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
,
SHUFFLE_AMPERE_PARAMS
from
spconv.core
import
IMPLGEMM_SIMT_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
,
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_AMPERE_PARAMS
...
...
spconv/constants.py
View file @
e387ee74
...
...
@@ -116,3 +116,5 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32
=
False
SPCONV_INT8_DEBUG
=
False
\ No newline at end of file
spconv/core.py
View file @
e387ee74
...
...
@@ -19,7 +19,7 @@ from cumm.gemm.algospec.core import TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.constants
import
NDIM_DONT_CARE
from
spconv.constants
import
NDIM_DONT_CARE
,
SPCONV_INT8_DEBUG
class
ConvAlgo
(
Enum
):
...
...
@@ -39,18 +39,18 @@ class AlgoHint(Enum):
# TODO two step build: build gemm kernels first, then bind for every python
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
#
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
#
2, kernel.GemmAlgo.SimtDP4A, None),
#
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
#
2, kernel.GemmAlgo.SimtDP4A, None),
#
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"],
#
"", 2, kernel.GemmAlgo.SimtDP4A, None),
#
*gen_shuffle_params(
#
(128, 128, 32),
#
(64, 32, 32), ["s8,s8,s8,s32,s32"], "", 2,
#
kernel.GemmAlgo.SimtDP4A, None),
#
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
#
2, kernel.GemmAlgo.SimtDP4A, None),
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# *gen_shuffle_params(
...
...
@@ -164,39 +164,39 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
# *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
# 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
# *gen_shuffle_params(
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s8,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
# (32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
# TensorOp((8, 8, 16))),
# # *gen_shuffle_params(
# # (128, 128, 32),
# # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
# *gen_shuffle_params(
# (128, 256, 32),
# (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
# TensorOp((8, 8, 16))),
# *gen_shuffle_params(
# (256, 128, 32),
# (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
# TensorOp((8, 8, 16))),
# *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
# 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
# *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
# 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
]
SHUFFLE_AMPERE_PARAMS
=
[
*
gen_shuffle_params
(
(
128
,
128
,
64
),
(
64
,
64
,
64
),
[
"s8,s8,s8,s32,s32"
],
""
,
3
,
kernel
.
GemmAlgo
.
Ampere
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
(
128
,
64
,
64
),
(
64
,
32
,
64
),
[
"s8,s8,s8,s32,s32"
],
""
,
3
,
kernel
.
GemmAlgo
.
Ampere
,
TensorOp
((
8
,
8
,
16
))),
#
*gen_shuffle_params(
#
(128, 128, 64),
#
(64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
#
TensorOp((8, 8, 16))),
#
*gen_shuffle_params(
#
(128, 64, 64),
#
(64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
#
TensorOp((8, 8, 16))),
]
# SHUFFLE_TURING_PARAMS = []
...
...
@@ -619,182 +619,170 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
]
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
if
not
SPCONV_INT8_DEBUG
:
IMPLGEMM_AMPERE_PARAMS
.
extend
([
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
)
,
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
)
,
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
]
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
])
IMPLGEMM_TURING_PARAMS
=
[
...
...
@@ -828,151 +816,6 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
256
,
64
),
(
64
,
128
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
256
,
128
,
64
),
(
128
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
128
),
(
64
,
64
,
128
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
16
,
16
),
(
16
,
16
,
16
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
...
...
@@ -1228,6 +1071,153 @@ IMPLGEMM_TURING_PARAMS = [
# gen_conv_params(ConvFwdAndBwdInput, )
]
if not SPCONV_INT8_DEBUG:
    # Additional int8 Turing kernels, registered only when SPCONV_INT8_DEBUG
    # is falsy. Every entry shares the same arguments except for three shape
    # tuples: the 2nd and 3rd positional arguments of gen_conv_params
    # (presumably the tile shape and warp tile shape -- confirm against
    # gen_conv_params) and the shape handed to TensorOp.
    IMPLGEMM_TURING_PARAMS.extend([
        params
        for ts, wts, tensorop_shape in (
            ((64, 64, 64), (32, 32, 64), (16, 8, 32)),
            ((64, 128, 64), (32, 64, 64), (16, 8, 32)),
            ((64, 128, 32), (32, 64, 32), (16, 8, 16)),
            ((128, 64, 64), (64, 32, 64), (16, 8, 32)),
            ((128, 64, 32), (64, 32, 32), (16, 8, 16)),
            ((128, 256, 64), (64, 128, 64), (16, 8, 32)),
            ((256, 128, 64), (128, 64, 64), (16, 8, 32)),
            ((128, 128, 128), (64, 64, 128), (16, 8, 32)),
            ((128, 128, 64), (64, 64, 64), (16, 8, 32)),
        )
        for params in gen_conv_params(
            ConvFwdAndBwdInput,
            ts,
            wts,
            NDIM_DONT_CARE,
            ConvIterAlgo.Optimized,
            2,
            ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
            NHWC,
            NHWC,
            NHWC,
            GemmAlgo.Turing,
            TensorOp(tensorop_shape),
            mask_sparse=True,
            increment_k_first=True,
            access_per_vector=1,
            is_nvrtc=True,
            int8_inference=True)
    ])
# Concatenate the per-architecture parameter collections into one collection
# per kernel family. This runs after the conditional extend of
# IMPLGEMM_TURING_PARAMS above, so those extra entries are included.
ALL_NATIVE_PARAMS = (
    SHUFFLE_SIMT_PARAMS
    + SHUFFLE_TURING_PARAMS
    + SHUFFLE_VOLTA_PARAMS
    + SHUFFLE_AMPERE_PARAMS
)
ALL_IMPGEMM_PARAMS = (
    IMPLGEMM_SIMT_PARAMS
    + IMPLGEMM_TURING_PARAMS
    + IMPLGEMM_VOLTA_PARAMS
    + IMPLGEMM_AMPERE_PARAMS
)
spconv/pytorch/__init__.py
View file @
e387ee74
...
...
@@ -14,7 +14,8 @@ from spconv.pytorch.conv import (SparseConv1d, SparseConv2d, SparseConv3d,
SubMConv3d
,
SubMConv4d
)
from
spconv.pytorch.identity
import
Identity
from
spconv.pytorch.modules
import
(
SparseModule
,
SparseSequential
,
assign_name_for_sparse_modules
)
assign_name_for_sparse_modules
,
SparseBatchNorm
,
SparseReLU
,
SparseIdentity
)
from
spconv.pytorch.ops
import
ConvAlgo
from
spconv.pytorch.pool
import
(
SparseMaxPool1d
,
SparseMaxPool2d
,
SparseMaxPool3d
,
SparseMaxPool4d
,
...
...
spconv/pytorch/conv.py
View file @
e387ee74
...
...
@@ -157,35 +157,10 @@ class SparseConvolutionBase:
batch_size
=
input
.
batch_size
bias_for_training
=
bias
if
training
else
None
bias_for_infer
=
bias
if
not
training
else
None
output_add_scale
=
1
.0
output_add_scale
=
0
.0
if
is_int8
:
if
add_input
is
not
None
:
output_add_scale
=
add_input
.
q_scale
()
# if self.enable_int8_test_mode:
# assert not self.training, "must in eval mode"
# assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# assert bias_for_infer is not None, "conv-bn-relu must be fused"
# assert self._int8_input_scale is not None
# if features.dtype != torch.int8:
# # quantize
# features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# output_scale = self._int8_output_scale
# int8_out_scale = output_scale
# if int8_out_scale is None:
# int8_out_scale = 1
# if add_input is not None:
# assert add_input.int8_scale is not None, "only support int8 add"
# output_add_scale = add_input.int8_scale
# if self._int8_weight.numel() == 0:
# with torch.no_grad():
# assert ALL_WEIGHT_IS_KRSC
# weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# num_1s = [1] * (self.ndim + 1)
# self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# if self._int8_weight_scale.numel() == 0:
# self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# self._int8_bias = bias_for_infer * int8_out_scale
if
training
:
msg
=
"act don't support backward, only used in inference"
assert
self
.
act_type
==
tv
.
gemm
.
Activation
.
None_
,
msg
...
...
@@ -340,9 +315,9 @@ class SparseConvolutionBase:
algo
,
input
.
_timer
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
act_alpha
,
act_beta
,
act_type
)
else
:
if
self
.
inverse
:
out_features
=
Fsp
.
indice_inverse_conv
(
...
...
@@ -354,9 +329,9 @@ class SparseConvolutionBase:
algo
,
input
.
_timer
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
act_alpha
,
act_beta
,
act_type
)
else
:
out_features
=
Fsp
.
indice_conv
(
features
,
...
...
@@ -367,10 +342,9 @@ class SparseConvolutionBase:
algo
,
input
.
_timer
,
bias_for_infer
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
act_type
,
act_beta
,
act_type
)
else
:
datas
=
input
.
find_indice_pair
(
self
.
indice_key
)
if
datas
is
not
None
:
...
...
@@ -490,9 +464,9 @@ class SparseConvolutionBase:
num_activate_out
,
masks
,
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
)
act_alpha
,
act_beta
,
act_type
)
else
:
output_dtype
=
None
if
output_scale
is
None
:
...
...
@@ -503,9 +477,9 @@ class SparseConvolutionBase:
num_activate_out
,
masks
,
training
,
self
.
subm
,
input
.
_timer
,
self
.
fp32_accum
,
bias_cur
,
self
.
act_alpha
,
self
.
act_beta
,
self
.
act_type
,
act_alpha
,
act_beta
,
act_type
,
# TODO do we really need output scale to scale bias in kernel?
1.0
if
output_scale
is
None
else
output_scale
,
# output_scale
channel_scale
,
# scale
...
...
@@ -764,446 +738,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
name
=
self
.
name
,
sparse_unique_name
=
self
.
_sparse_unique_name
,
act_type
=
self
.
act_type
,
act_alpha
=
self
.
act_alpha
,
act_beta
=
self
.
act_beta
)
# def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
# channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None):
# assert isinstance(input, SparseConvTensor)
# is_int8 = input.is_quantized and weight.is_quantized
# if is_int8:
# assert output_scale is not None and channel_scale is not None, "int8 must be called in static quantized module"
# assert bias is not None, "currently you must specify a bias"
# assert input.features.shape[
# 1] == self.in_channels, "channel size mismatch"
# features = input.features
# device = features.device
# indices = input.indices
# spatial_shape = input.spatial_shape
# batch_size = input.batch_size
# bias_for_training = bias if self.training else None
# bias_for_infer = bias if not self.training else None
# output_add_scale = 1.0
# if is_int8:
# if add_input is not None:
# output_add_scale = add_input.q_scale()
# # if self.enable_int8_test_mode:
# # assert not self.training, "must in eval mode"
# # assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# # assert bias_for_infer is not None, "conv-bn-relu must be fused"
# # assert self._int8_input_scale is not None
# # if features.dtype != torch.int8:
# # # quantize
# # features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# # output_scale = self._int8_output_scale
# # int8_out_scale = output_scale
# # if int8_out_scale is None:
# # int8_out_scale = 1
# # if add_input is not None:
# # assert add_input.int8_scale is not None, "only support int8 add"
# # output_add_scale = add_input.int8_scale
# # if self._int8_weight.numel() == 0:
# # with torch.no_grad():
# # assert ALL_WEIGHT_IS_KRSC
# # weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# # num_1s = [1] * (self.ndim + 1)
# # self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# # if self._int8_weight_scale.numel() == 0:
# # self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# # self._int8_bias = bias_for_infer * int8_out_scale
# if self.training:
# msg = "act don't support backward, only used in inference"
# assert self.act_type == tv.gemm.Activation.None_, msg
# if not self.subm:
# if self.transposed:
# out_spatial_shape = ops.get_deconv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding)
# else:
# out_spatial_shape = ops.get_conv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation)
# else:
# out_spatial_shape = spatial_shape
# # print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# # input.update_grid(out_spatial_shape)
# # t = time.time()
# out_tensor = input.shadow_copy()
# if input.benchmark:
# if self.name is None:
# raise ValueError(
# "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
# )
# if self.name not in input.benchmark_record:
# input.benchmark_record[self.name] = {
# "type": "SparseConvolution",
# "indice_gen_time": [],
# "time": [],
# "num_points": [],
# "num_out_points": [],
# "params": {
# "kernel_size": self.kernel_size,
# "stride": self.stride,
# "padding": self.padding,
# "dilation": self.dilation,
# "output_padding": self.output_padding,
# "subm": self.subm,
# "transposed": self.transposed,
# "input_channels": self.in_channels,
# "out_channels": self.out_channels,
# }
# }
# if self.conv1x1 and not is_int8:
# # in int8 test mode, we don't implement conv1x1 via mm.
# if FILTER_HWIO:
# features = torch.mm(
# input.features,
# weight.view(self.out_channels, self.in_channels).T)
# else:
# features = torch.mm(
# input.features,
# weight.view(self.in_channels, self.out_channels))
# if bias is not None:
# features += bias
# out_tensor = out_tensor.replace_feature(features)
# # padding may change spatial shape of conv 1x1.
# out_tensor.spatial_shape = out_spatial_shape
# return out_tensor
# indice_dict = input.indice_dict.copy()
# # only support contiguous tensor for now
# if not features.is_contiguous():
# features = features.contiguous()
# algo = self.algo
# if self.indice_key is not None:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
# assert algo == datas.algo, msg
# # algo = datas.algo
# profile_ctx = nullcontext()
# if input._timer is not None and self._sparse_unique_name:
# profile_ctx = input._timer.namespace(self._sparse_unique_name)
# with profile_ctx:
# if algo == ConvAlgo.Native:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, IndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# out_spatial_shape = datas.spatial_shape
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# try:
# outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
# indices, batch_size, spatial_shape, algo,
# self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, self.subm,
# self.transposed)
# except Exception as e:
# msg = "[Exception|native_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# indice_data = IndiceData(outids,
# indices,
# indice_pairs,
# indice_pair_num,
# spatial_shape,
# out_spatial_shape,
# is_subm=self.subm,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# if self.indice_key is not None:
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# indice_pairs_calc = indice_pairs
# if indice_pairs.device != features.device:
# indice_pairs_calc = indice_pairs.to(features.device)
# if self.subm:
# out_features = Fsp.indice_subm_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# if self.inverse:
# out_features = Fsp.indice_inverse_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# out_features = Fsp.indice_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, ImplicitGemmIndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# pair_fwd = datas.pair_bwd
# pair_bwd = datas.pair_fwd
# pair_mask_fwd_splits = datas.pair_mask_bwd_splits
# pair_mask_bwd_splits = datas.pair_mask_fwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
# masks = datas.masks
# out_spatial_shape = datas.spatial_shape
# # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# pair_fwd = datas.pair_fwd
# pair_bwd = datas.pair_bwd
# pair_mask_fwd_splits = datas.pair_mask_fwd_splits
# pair_mask_bwd_splits = datas.pair_mask_bwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
# masks = datas.masks
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# with input._timer.namespace("gen_pairs"):
# # we need to gen bwd indices for regular conv
# # because it may be inversed.
# try:
# res = ops.get_indice_pairs_implicit_gemm(
# indices,
# batch_size,
# spatial_shape,
# algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation,
# out_padding=self.output_padding,
# subm=self.subm,
# transpose=self.transposed,
# is_train=(not self.subm) or self.training,
# alloc=input.thrust_allocator,
# timer=input._timer)
# except Exception as e:
# msg = "[Exception|implicit_gemm_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# outids = res[0]
# num_inds_per_loc = res[1]
# pair_fwd = res[2]
# pair_bwd = res[3]
# pair_mask_fwd_splits = res[4]
# pair_mask_bwd_splits = res[5]
# mask_argsort_fwd_splits = res[6]
# mask_argsort_bwd_splits = res[7]
# masks = res[8]
# if self.indice_key is not None:
# indice_data = ImplicitGemmIndiceData(
# outids,
# indices,
# pair_fwd,
# pair_bwd,
# pair_mask_fwd_splits=pair_mask_fwd_splits,
# pair_mask_bwd_splits=pair_mask_bwd_splits,
# mask_argsort_fwd_splits=mask_argsort_fwd_splits,
# mask_argsort_bwd_splits=mask_argsort_bwd_splits,
# masks=masks,
# is_subm=self.subm,
# spatial_shape=spatial_shape,
# out_spatial_shape=out_spatial_shape,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# num_activate_out = outids.shape[0]
# weight_cur = weight
# bias_cur = bias_for_infer
# # if self.enable_int8_test_mode:
# # assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
# # weight_cur = self._int8_weight
# # bias_cur = self._int8_bias
# if self.training:
# out_features = Fsp.implicit_gemm(
# features, weight_cur, pair_fwd, pair_bwd,
# pair_mask_fwd_splits, pair_mask_bwd_splits,
# mask_argsort_fwd_splits, mask_argsort_bwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# output_dtype = None
# if output_scale is None:
# output_dtype = weight.dtype
# out_features, _, _ = ops.implicit_gemm(
# features, weight_cur, pair_fwd, pair_mask_fwd_splits,
# mask_argsort_fwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type,
# # TODO do we really need output scale to scale bias in kernel?
# 1.0 if output_scale is None else output_scale, # output_scale
# channel_scale, # scale
# output_add=add_input.features if add_input is not None else None,
# output_add_scale=output_add_scale,
# output_dtype=output_dtype)
# if bias_for_training is not None:
# out_features += bias_for_training
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[self.name]["time"].append(interval)
# out_tensor.benchmark_record[self.name]["num_points"].append(
# features.shape[0])
# out_tensor.benchmark_record[self.name]["num_out_points"].append(
# out_features.shape[0])
# if not self.subm and not self.inverse and self.record_voxel_count:
# if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
# ops.maximum_value_int_(
# getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
# outids.shape[0])
# out_tensor = out_tensor.replace_feature(out_features)
# out_tensor.indices = outids
# out_tensor.indice_dict = indice_dict
# out_tensor.spatial_shape = out_spatial_shape
# if add_input is not None and not is_int8:
# # in int8, we apply add + act in kernel.
# out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
# return out_tensor
# def _check_subm_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# assert datas.is_subm, "only support reuse subm indices"
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"subm with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}")
# if self.dilation != datas.dilation:
# raise ValueError(
# f"subm with same indice_key must have same dilation"
# f", expect {datas.dilation}, this layer {self.dilation}")
# if inp.spatial_shape != datas.spatial_shape:
# raise ValueError(
# f"subm with same indice_key must have same spatial structure"
# f", expect {datas.spatial_shape}, input {spatial_shape}")
# if inp.indices.shape[0] != datas.indices.shape[0]:
# raise ValueError(
# f"subm with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
# )
# def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"Inverse with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.spatial_shape != datas.out_spatial_shape:
# raise ValueError(
# f"Inverse with same indice_key must have same spatial structure (spatial shape)"
# f", expect {datas.spatial_shape}, input {spatial_shape}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.indices.shape[0] != datas.out_indices.shape[0]:
# raise ValueError(
# f"Inverse with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
# "please check Inverse Convolution in ."
# )
class
SparseConv1d
(
SparseConvolution
):
def
__init__
(
self
,
...
...
spconv/pytorch/core.py
View file @
e387ee74
...
...
@@ -233,6 +233,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
features_th
=
x_sp
.
values
()
return
cls
(
features_th
,
indices_th
,
spatial_shape
,
batch_size
)
def dequantize(self):
    """Return ``self.replace_feature`` applied to the dequantized features.

    The features are converted with their own ``dequantize()`` method;
    everything else is left to ``replace_feature``.
    """
    dequantized_features = self.features.dequantize()
    return self.replace_feature(dequantized_features)
@
property
def
spatial_size
(
self
):
return
np
.
prod
(
self
.
spatial_shape
)
...
...
@@ -264,6 +267,19 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
# return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size
def __add__(self, other: "SparseConvTensor"):
    """Implement ``self + other`` by adding the two feature tensors.

    The sum ``self.features + other.features`` is handed to
    ``self.replace_feature`` and its result is returned. Nothing here checks
    that both operands describe the same set of indices -- presumably the
    caller guarantees that (TODO confirm).

    Returns ``NotImplemented`` when ``other`` is not a SparseConvTensor, so
    Python can try the reflected operation or raise ``TypeError``. An
    ``assert`` is not used for this check because assertions are stripped
    under ``python -O``.
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    return self.replace_feature(self.features + other.features)
def __iadd__(self, other: "SparseConvTensor"):
    """Implement ``self += other`` by adding ``other.features`` in place.

    Returns ``self`` after the augmented assignment on ``self.features``.

    Returns ``NotImplemented`` when ``other`` is not a SparseConvTensor, so
    Python falls back to ``__add__`` / ``__radd__`` or raises ``TypeError``.
    An ``assert`` is not used for this check because assertions are stripped
    under ``python -O``.
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    # NOTE(review): ``+=`` on an attribute also re-assigns ``self.features``;
    # this assumes ``features`` is assignable -- confirm against the class.
    self.features += other.features
    return self
def __radd__(self, other: "SparseConvTensor"):
    """Implement the reflected addition ``other + self``.

    Python calls this when the left operand did not handle the addition.
    The sum of both feature tensors is handed to ``other.replace_feature``
    (``other`` being the left operand) and its result is returned.

    Returns ``NotImplemented`` when ``other`` is not a SparseConvTensor
    (for example the integer start value used by ``sum()``), so Python raises
    ``TypeError``. An ``assert`` is not used for this check because
    assertions are stripped under ``python -O``.
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    return other.replace_feature(self.features + other.features)
def
shadow_copy
(
self
)
->
"SparseConvTensor"
:
"""create a new spconv tensor with all member unchanged"""
tensor
=
SparseConvTensor
(
self
.
features
,
self
.
indices
,
...
...
spconv/pytorch/modules.py
View file @
e387ee74
...
...
@@ -23,7 +23,7 @@ from spconv import pytorch as spconv
def is_spconv_module(module):
    """Return True if ``module`` is a SparseModule, SparseBatchNorm or SparseReLU instance."""
    # SparseBatchNorm and SparseReLU are defined further down in this module;
    # the names are only looked up when this function is called.
    spconv_modules = (SparseModule, SparseBatchNorm, SparseReLU)
    return isinstance(module, spconv_modules)
...
...
@@ -148,3 +148,37 @@ def assign_name_for_sparse_modules(module: nn.Module):
for
k
,
n
in
module
.
named_modules
():
if
isinstance
(
n
,
SparseModule
):
n
.
_sparse_unique_name
=
k
class SparseBatchNorm(nn.BatchNorm1d):
    """BatchNorm1d that also accepts a SparseConvTensor.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        # Anything that is not a sparse tensor goes straight to BatchNorm1d.
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        # Sparse tensor: normalize its features and wrap them again.
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm that also accepts a SparseConvTensor.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        # Anything that is not a sparse tensor goes straight to SyncBatchNorm.
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        # Sparse tensor: normalize its features and wrap them again.
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseReLU(nn.ReLU):
    """ReLU that also accepts a SparseConvTensor.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        # Anything that is not a sparse tensor goes straight to nn.ReLU.
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        # Sparse tensor: activate its features and wrap them again.
        activated = super().forward(input.features)
        return input.replace_feature(activated)
class SparseIdentity(nn.Identity):
    """Identity that also accepts a SparseConvTensor.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        # Anything that is not a sparse tensor goes straight to nn.Identity.
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        # Sparse tensor: pass its features through and wrap them again.
        passed_through = super().forward(input.features)
        return input.replace_feature(passed_through)
spconv/pytorch/ops.py
View file @
e387ee74
...
...
@@ -1462,14 +1462,14 @@ def implicit_gemm(features: torch.Tensor,
output_scale
:
float
=
1.0
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_add
:
Optional
[
torch
.
Tensor
]
=
None
,
output_add_scale
:
float
=
1
.0
,
output_add_scale
:
float
=
0
.0
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
stream
=
get_current_stream
()
bias_tv
=
tv
.
Tensor
()
scale_tv
=
tv
.
Tensor
()
output_add_tv
=
tv
.
Tensor
()
if
output_add
is
not
None
:
assert
features
.
dtype
==
torch
.
int8
,
"fused residual add only support int8"
assert
features
.
dtype
==
torch
.
q
int8
,
"fused residual add only support int8"
if
bias
is
not
None
:
bias_tv
=
torch_tensor_to_tv
(
bias
)
if
scale
is
not
None
:
...
...
@@ -1485,7 +1485,7 @@ def implicit_gemm(features: torch.Tensor,
output_dtype
=
features
.
dtype
if
SPCONV_CPP_GEMM
and
CONV_CPP
is
not
None
:
alloc
=
TorchAllocator
(
features
.
device
)
alloc
=
TorchAllocator
(
features
.
device
,
features
.
dtype
==
torch
.
qint8
)
features_tv
=
torch_tensor_to_tv
(
features
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_mask_fwd_splits_tv
=
[
...
...
@@ -1963,9 +1963,15 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor,
features
=
features
.
contiguous
()
out_channel
=
features
.
shape
[
-
1
]
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
if
features
.
is_quantized
:
out_features
=
torch
.
_empty_affine_quantized
((
num_activate_out
,
out_channel
),
scale
=
features
.
q_scale
(),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
else
:
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
assert
features
.
is_cuda
stream
=
get_current_stream
()
out_features_tv
=
torch_tensor_to_tv
(
out_features
)
...
...
@@ -2016,9 +2022,16 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
features
=
features
.
contiguous
()
out_channel
=
features
.
shape
[
-
1
]
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
if
features
.
is_quantized
:
out_features
=
torch
.
_empty_affine_quantized
((
num_activate_out
,
out_channel
),
scale
=
features
.
q_scale
(),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
else
:
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
features
.
dtype
,
device
=
features
.
device
)
assert
features
.
is_cuda
stream
=
get_current_stream
()
out_features_tv
=
torch_tensor_to_tv
(
out_features
)
...
...
spconv/pytorch/pool.py
View file @
e387ee74
...
...
@@ -66,14 +66,14 @@ class SparseMaxPool(SparseModule):
if
algo
is
None
:
# keep in mind that this algorithm is set for Inverse Sparse Conv
# maxpool itself don't need mask.
if
kv
<=
32
and
not
CPU_ONLY_BUILD
:
if
kv
<=
128
and
not
CPU_ONLY_BUILD
:
if
kv
<
8
:
algo
=
ConvAlgo
.
MaskImplicitGemm
else
:
algo
=
ConvAlgo
.
MaskImplicitGemm
else
:
algo
=
ConvAlgo
.
Native
if
kv
>
32
:
if
kv
>
128
:
assert
algo
==
ConvAlgo
.
Native
,
"implicit gemm don't support kv >= 32 for now"
if
CPU_ONLY_BUILD
:
assert
algo
==
ConvAlgo
.
Native
,
"cpu only build only support native algorithm"
...
...
@@ -96,7 +96,10 @@ class SparseMaxPool(SparseModule):
return
None
def
forward
(
self
,
input
):
def
forward
(
self
,
input
:
spconv
.
SparseConvTensor
):
is_int8
=
input
.
is_quantized
if
is_int8
:
assert
self
.
algo
==
ConvAlgo
.
MaskImplicitGemm
,
"only ConvAlgo.MaskImplicitGemm support int8."
assert
isinstance
(
input
,
spconv
.
SparseConvTensor
)
features
=
input
.
features
device
=
features
.
device
...
...
@@ -296,6 +299,10 @@ class SparseAvgPool(SparseModule):
def
forward
(
self
,
input
):
assert
isinstance
(
input
,
spconv
.
SparseConvTensor
)
is_int8
=
input
.
is_quantized
if
is_int8
:
assert
self
.
algo
==
ConvAlgo
.
MaskImplicitGemm
,
"only ConvAlgo.MaskImplicitGemm support int8."
features
=
input
.
features
device
=
features
.
device
indices
=
input
.
indices
...
...
@@ -534,3 +541,8 @@ class SparseAvgPool3d(SparseAvgPool):
algo
=
algo
,
record_voxel_count
=
record_voxel_count
,
name
=
name
)
# Sparse pooling module classes grouped into a single set so that other
# modules can iterate over them or test membership in one place.
ALL_POOL_LAYERS = {
    SparseAvgPool3d,
    SparseAvgPool2d,
    SparseAvgPool1d,
    SparseMaxPool1d,
    SparseMaxPool2d,
    SparseMaxPool3d,
    SparseMaxPool4d,
    SparseAvgPool,
    SparseMaxPool,
}
\ No newline at end of file
spconv/pytorch/quantization/__init__.py
View file @
e387ee74
...
...
@@ -19,3 +19,4 @@ from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig
)
from
.qmapping
import
(
get_spconv_fmod_to_qat_mapping
,
get_spconv_qat_to_static_mapping
)
from
.core
import
quantize_per_tensor
\ No newline at end of file
spconv/pytorch/quantization/backend_cfg.py
View file @
e387ee74
from
collections
import
namedtuple
from
typing
import
List
,
Dict
,
Union
,
Type
,
Tuple
import
operator
from
typing
import
Dict
,
List
,
Tuple
,
Type
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.ao.quantization.fx.match_utils
import
(
MatchAllNode
,
)
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
from
torch.ao.quantization.backend_config
import
(
BackendConfig
,
BackendPatternConfig
,
DTypeConfig
,
ObservationType
,
...
...
@@ -11,56 +15,184 @@ from torch.ao.quantization.backend_config import (BackendConfig,
from
torch.ao.quantization.fx.custom_config
import
(
ConvertCustomConfig
,
FuseCustomConfig
,
PrepareCustomConfig
)
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
import
spconv.pytorch.conv
as
sconvmod
from
spconv.pytorch.modules
import
SparseBatchNorm
,
SparseIdentity
,
SparseReLU
,
SparseSyncBatchNorm
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
import
spconv.pytorch.quantization.quantized
as
snnq
import
spconv.pytorch.quantization.quantized.reference
as
snnqr
from
spconv.pytorch
import
ToDense
from
spconv.pytorch.constants
import
PYTORCH_VERSION
from
spconv.pytorch.pool
import
ALL_POOL_LAYERS
from
spconv.pytorch.quantization.fuse_mapping
import
(
fuse_conv_bn
,
fuse_conv_bn_relu
)
from
spconv.pytorch
import
ToDense
fuse_conv_bn_relu
,
fuse_conv_bn_add_relu
)
# Bundle of the module classes that belong to one sparse-conv root class:
# the float root, its batch norm, the reference quantized module, the fused
# float modules and the matching QAT modules.
_SpConvMetadataDef = namedtuple("_ConvMetadata", [
    "root",
    "bn",
    "reference",
    "fused_conv_relu",
    "fused_conv_bn",
    "fused_conv_bn_relu",
    "fused_conv_add_relu",
    "fused_conv_bn_add_relu",
    "qat",
    "relu_qat",
    "bn_qat",
    "bn_relu_qat",
    "add_relu_qat",
    "bn_add_relu_qat",
])

# One metadata entry per default sparse conv type, followed by one for the
# generic SparseConvolution class. Every entry supplies all 14 fields; the
# earlier 10-field construction calls no longer match the namedtuple and
# are not kept.
_SpConvMetadatas: List[_SpConvMetadataDef] = [
    _SpConvMetadataDef(
        root=root_cls,
        bn=nn.BatchNorm1d,
        reference=snnqr.SpConv,
        fused_conv_relu=snni.SpconvReLUNd,
        fused_conv_bn=snni.SpconvBnNd,
        fused_conv_bn_relu=snni.SpconvBnReLUNd,
        fused_conv_add_relu=snni.SpconvAddReLUNd,
        fused_conv_bn_add_relu=snni.SpconvBnAddReLUNd,
        qat=snniqat.SparseConv,
        relu_qat=snniqat.SparseConvReLU,
        bn_qat=snniqat.SparseConvBn,
        bn_relu_qat=snniqat.SparseConvBnReLU,
        add_relu_qat=snniqat.SparseConvAddReLU,
        bn_add_relu_qat=snniqat.SparseConvBnAddReLU,
    )
    for root_cls in [*sconvmod.DEFAULT_SPARSE_CONV_TYPES,
                     sconvmod.SparseConvolution]
]
def
_sequential_wrapper2
(
sequential
):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def
fuser_method
(
is_qat
,
m1
,
m2
):
return
sequential
(
m1
,
m2
)
return
fuser_method
# new cfg remove reverse pattern.
def
_get_spconv_configs
(
dtype_configs
):
def
_conv_bn_res_relu_root_node_getter
(
pattern
):
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
_
=
add_pattern
bn
,
conv
=
bn_pattern
return
conv
def
_conv_bn_res_relu_extra_inputs_getter
(
pattern
):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
extra_input
=
add_pattern
bn
,
conv
=
bn_pattern
return
[
extra_input
]
def
_conv_res_relu_root_node_getter
(
pattern
):
relu
,
add_pattern
=
pattern
_
,
conv
,
_
=
add_pattern
return
conv
def
_conv_res_relu_extra_inputs_getter
(
pattern
):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu
,
add_pattern
=
pattern
_
,
conv
,
extra_input
=
add_pattern
return
[
extra_input
]
def _get_bn_spconv_configs(bn_cls, dtype_configs):
    """Build the conv + batch-norm fusion pattern configs for one bn class.

    For every entry of ``_SpConvMetadatas`` three kinds of fusion patterns
    are registered:

    * conv + bn                -> ``convs.fused_conv_bn``
    * conv + bn + relu         -> ``convs.fused_conv_bn_relu``
    * conv + bn + add + relu   -> ``convs.fused_conv_bn_add_relu``

    Args:
        bn_cls: batch-norm module class used in the patterns.
        dtype_configs: list of ``DTypeConfig`` applied to every pattern.

    Returns:
        list of ``BackendPatternConfig``.
    """
    conv_configs = []
    relu_types = [torch.nn.ReLU, F.relu, SparseReLU]
    add_ops = [torch.add, operator.add]
    if PYTORCH_VERSION[:2] <= [1, 13]:
        # PyTorch <= 1.13: patterns are written in reversed order (last op
        # first) and the fuser methods are adapted with reverse2/reverse3.
        # These helpers are imported lazily inside this branch only.
        from torch.ao.quantization.fuser_method_mappings import (
            reverse2, reverse3)
        for convs in _SpConvMetadatas:
            # (3) Conv + batchnorm (+ relu)
            # -------------------------------
            # 3.1 conv + bn fusion
            conv_configs.append(
                BackendPatternConfig((bn_cls, convs.root))
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_fuser_method(reverse2(fuse_conv_bn))
                .set_fused_module(convs.fused_conv_bn))
            # conv + bn + relu fusion (module, functional and sparse relu)
            for relu_type in relu_types:
                conv_configs.append(
                    BackendPatternConfig((relu_type, (bn_cls, convs.root)))
                    .set_dtype_configs(dtype_configs)  # noqa: E131
                    .set_fuser_method(reverse3(fuse_conv_bn_relu))
                    .set_fused_module(convs.fused_conv_bn_relu))
            # 5.1 fuse conv + bn + add + relu to one op
            for add_op in add_ops:
                for relu_op in [SparseReLU]:
                    conv_configs.append(
                        BackendPatternConfig(
                            (relu_op,
                             (add_op, (bn_cls, convs.root), MatchAllNode)))
                        .set_dtype_configs(dtype_configs)
                        # .set_root_module(convs.root)
                        .set_fuser_method(fuse_conv_bn_add_relu)
                        ._set_root_node_getter(
                            _conv_bn_res_relu_root_node_getter)
                        ._set_extra_inputs_getter(
                            _conv_bn_res_relu_extra_inputs_getter)
                        .set_fused_module(convs.fused_conv_bn_add_relu))
        return conv_configs
    else:
        # Newer PyTorch: plain patterns are written in forward order and the
        # fuser methods are used directly.
        for convs in _SpConvMetadatas:
            # (3) Conv + batchnorm (+ relu)
            # -------------------------------
            # 3.1 conv + bn fusion
            conv_configs.append(
                BackendPatternConfig((convs.root, bn_cls))
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_fuser_method(fuse_conv_bn)
                .set_fused_module(convs.fused_conv_bn))
            # conv + bn + relu fusion (module, functional and sparse relu)
            for relu_type in relu_types:
                conv_configs.append(
                    BackendPatternConfig((convs.root, bn_cls, relu_type))
                    .set_dtype_configs(dtype_configs)  # noqa: E131
                    .set_fuser_method(fuse_conv_bn_relu)
                    .set_fused_module(convs.fused_conv_bn_relu))
            # (5) conv add and its fusion
            # 5.1 fuse conv + bn + add + relu to one op; the residual
            # pattern is registered through the complex (nested) format.
            for add_op in add_ops:
                for relu_op in [SparseReLU]:
                    conv_configs.append(
                        BackendPatternConfig()
                        ._set_pattern_complex_format(
                            (relu_op,
                             (add_op, (bn_cls, convs.root), MatchAllNode)))
                        .set_dtype_configs(dtype_configs)
                        # .set_root_module(convs.root)
                        .set_fuser_method(fuse_conv_bn_add_relu)
                        ._set_root_node_getter(
                            _conv_bn_res_relu_root_node_getter)
                        ._set_extra_inputs_getter(
                            _conv_bn_res_relu_extra_inputs_getter)
                        .set_fused_module(convs.fused_conv_bn_add_relu))
        return conv_configs
def
_get_spconv_configs
(
dtype_configs
):
"""
Return all configs related to conv modules and ops.
"""
conv_configs
=
(
_get_bn_spconv_configs
(
SparseBatchNorm
,
dtype_configs
)
+
_get_bn_spconv_configs
(
nn
.
BatchNorm1d
,
dtype_configs
)
+
_get_bn_spconv_configs
(
SparseSyncBatchNorm
,
dtype_configs
))
observation_type
=
ObservationType
.
OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
if
PYTORCH_VERSION
[:
2
]
<=
[
1
,
13
]:
from
torch.ao.quantization.fuser_method_mappings
import
(
reverse2
,
reverse3
,
reverse_sequential_wrapper2
)
for
convs
in
_SpConvMetadatas
:
...
...
@@ -68,114 +200,90 @@ def _get_spconv_configs(dtype_configs):
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
BackendPatternConfig
(
convs
.
root
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
).
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
torch
.
nn
.
ReLU
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
for
relu_type
in
[
torch
.
nn
.
ReLU
,
F
.
relu
,
SparseReLU
]:
conv_configs
.
append
(
BackendPatternConfig
(
(
relu_type
,
convs
.
root
)).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_relu
)).
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_relu
).
set_
observation_type
(
observation_type
)
# noqa: E131
.
set_
dtype_configs
(
dtype_configs
).
set_
root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# 2.3 functional conv + relu configs
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
bn
,
convs
.
root
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse2
(
fuse_conv_bn
))
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
nn
.
ReLU
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
F
.
relu
,
(
convs
.
bn
,
convs
.
root
)))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
reverse3
(
fuse_conv_bn_relu
))
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# (4) conv transpose and its fusion
# 4.1 conv transpose config
...
...
@@ -192,6 +300,56 @@ def _get_spconv_configs(dtype_configs):
# .set_fuser_method(reverse2(fuse_conv_bn))
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# (5) conv add and its fusion
# 5.1 fuse conv + bn + add + relu to one op
for
add_op
in
[
torch
.
add
,
operator
.
add
]:
for
relu_op
in
[
SparseReLU
]:
conv_configs
.
append
(
BackendPatternConfig
((
relu_op
,
(
add_op
,
convs
.
root
,
MatchAllNode
)))
.
set_dtype_configs
(
dtype_configs
)
# .set_root_module(convs.root)
.
set_fuser_method
(
reverse_sequential_wrapper2
(
convs
.
fused_conv_add_relu
))
\
.
_set_root_node_getter
(
_conv_res_relu_root_node_getter
)
\
.
_set_extra_inputs_getter
(
_conv_res_relu_extra_inputs_getter
)
.
set_fused_module
(
convs
.
fused_conv_add_relu
))
# 5.2 fused add
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_add_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
add_relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_add_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_add_relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_add_relu
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
).
set_qat_module
(
convs
.
add_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
add_relu_qat
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_add_relu_qat
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
return
conv_configs
else
:
for
convs
in
_SpConvMetadatas
:
...
...
@@ -199,114 +357,102 @@ def _get_spconv_configs(dtype_configs):
# -----------------------------------
# conv module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
root
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
qat
))
BackendPatternConfig
(
convs
.
root
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
).
set_qat_module
(
convs
.
qat
))
# conv qat module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
torch
.
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
# conv relu fusion, conv module + functional relu
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
))
.
set_fused_module
(
convs
.
fused_conv_relu
))
for
relu_type
in
[
torch
.
nn
.
ReLU
,
F
.
relu
,
SparseReLU
]:
conv_configs
.
append
(
BackendPatternConfig
(
(
convs
.
root
,
relu_type
)).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_relu
)).
set_fused_module
(
convs
.
fused_conv_relu
))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_relu
).
set_
observation_type
(
observation_type
)
# noqa: E131
.
set_
dtype_configs
(
dtype_configs
).
set_
root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
)
.
set_qat_module
(
convs
.
relu_qat
))
# conv relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# fused conv relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
relu_qat
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn
)
.
set_fused_module
(
convs
.
fused_conv_bn
))
# conv + bn + relu module fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
nn
.
ReLU
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# conv + bn + relu functional fusion
conv_configs
.
append
(
BackendPatternConfig
((
convs
.
root
,
convs
.
bn
,
F
.
relu
))
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_root_module
(
convs
.
root
)
.
set_fuser_method
(
fuse_conv_bn_relu
)
.
set_fused_module
(
convs
.
fused_conv_bn_relu
))
# # conv + bn + relu functional fusion
# conv_configs.append(
# BackendPatternConfig((convs.root, convs.bn, F.relu))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.root)
# .set_fuser_method(fuse_conv_bn_relu)
# .set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
BackendPatternConfig
(
convs
.
fused_conv_bn
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_qat
))
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
)
.
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
BackendPatternConfig
(
convs
.
fused_conv_bn_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
bn_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# conv bn relu, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
)
.
set_reference_quantized_module
(
convs
.
reference
))
BackendPatternConfig
(
convs
.
bn_relu_qat
)
.
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
)
.
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
# # (4) conv transpose and its fusion
# # 4.1 conv transpose config
...
...
@@ -323,38 +469,123 @@ def _get_spconv_configs(dtype_configs):
# .set_fuser_method(fuse_conv_bn)
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# (5) conv add and its fusion
# 5.1 fuse conv + bn + add + relu to one op
for
add_op
in
[
torch
.
add
,
operator
.
add
]:
for
relu_op
in
[
SparseReLU
]:
conv_configs
.
append
(
BackendPatternConfig
()
\
.
_set_pattern_complex_format
((
relu_op
,
(
add_op
,
convs
.
root
,
MatchAllNode
)))
.
set_dtype_configs
(
dtype_configs
)
# .set_root_module(convs.root)
.
set_fuser_method
(
_sequential_wrapper2
(
convs
.
fused_conv_add_relu
))
\
.
_set_root_node_getter
(
_conv_res_relu_root_node_getter
)
\
.
_set_extra_inputs_getter
(
_conv_res_relu_extra_inputs_getter
)
.
set_fused_module
(
convs
.
fused_conv_add_relu
))
# 5.2 fused add
# fused conv bn relu
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_add_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
add_relu_qat
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
fused_conv_bn_add_relu
).
set_dtype_configs
(
dtype_configs
)
# noqa: E131
.
set_qat_module
(
convs
.
bn_add_relu_qat
))
# conv bn, qat fused module
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
add_relu_qat
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
conv_configs
.
append
(
BackendPatternConfig
(
convs
.
bn_add_relu_qat
).
set_observation_type
(
observation_type
)
# noqa: E131
.
set_dtype_configs
(
dtype_configs
).
set_root_module
(
convs
.
root
).
set_reference_quantized_module
(
convs
.
reference
))
return
conv_configs
def _get_share_observer_ops(dtype_configs):
    """Build pattern configs for modules whose output shares the input observer.

    One ``BackendPatternConfig`` is produced for ``ToDense``, one for
    ``SparseIdentity`` and one for each class in ``ALL_POOL_LAYERS`` (in that
    order). Each uses ``OUTPUT_SHARE_OBSERVER_WITH_INPUT`` and the given
    ``dtype_configs``.
    """
    shared_observer = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
    module_types = [ToDense, SparseIdentity, *ALL_POOL_LAYERS]
    return [
        BackendPatternConfig(module_type)
        .set_observation_type(shared_observer)
        .set_dtype_configs(dtype_configs)
        for module_type in module_types
    ]
# dtype config for weighted ops (sparse convs): int8 activations and
# weights with a float bias.
weighted_op_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.qint8,
    output_dtype=torch.qint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# dtype config for ops without weights (ToDense, identity, pooling).
non_weighted_op_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.qint8,
    output_dtype=torch.qint8,
)

conv_dtype_configs = [
    weighted_op_qint8_dtype_config,
]

# Start from the TensorRT backend config and register the spconv patterns
# together with the observer-sharing ops. The pattern list is set exactly
# once; the former `+ [_to_dense_cfg]` registration is covered by
# _get_share_observer_ops and is not repeated here.
backend_config = get_tensorrt_backend_config() \
    .set_backend_pattern_configs(
        _get_spconv_configs(conv_dtype_configs) +
        _get_share_observer_ops([non_weighted_op_qint8_dtype_config]))

# fused float module -> (reference quantized module, lowered quantized module).
# Defined once so the SpconvAddReLUNd entry is not overwritten by a later,
# shorter definition of the same name.
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
    Type[nn.Module], Type[WeightedQuantizedModule]]] = {
        snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
        snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU),
    }

# reference quantized module -> lowered quantized module.
SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module],
                                     Type[WeightedQuantizedModule]] = {
                                         snnqr.SpConv: snnq.SparseConv,
                                     }
def get_spconv_backend_config():
    """Return the module-level spconv ``backend_config`` object."""
    return backend_config
def get_spconv_prepare_custom_config():
    """Create the ``PrepareCustomConfig`` used when preparing spconv models.

    Marks the default sparse conv types, ``SparseReLU``, ``SparseBatchNorm``
    and ``SparseSyncBatchNorm`` as non-traceable module classes.

    Returns:
        the populated ``PrepareCustomConfig``.
    """
    cfg = PrepareCustomConfig()
    cfg.non_traceable_module_classes = [*sconvmod.DEFAULT_SPARSE_CONV_TYPES]
    # No early return here: the extend below must run so the relu and
    # batch-norm wrappers are registered as well.
    cfg.non_traceable_module_classes.extend(
        [SparseReLU, SparseBatchNorm, SparseSyncBatchNorm])
    return cfg
def get_spconv_convert_custom_config():
    """Create the ``ConvertCustomConfig`` with spconv observed->quantized maps.

    Returns:
        the populated ``ConvertCustomConfig``.
    """
    cfg = ConvertCustomConfig()
    # Each observed (fused float) class is registered once.
    cfg.set_observed_to_quantized_mapping(snni.SpconvReLUNd,
                                          snniq.SparseConvReLU)
    # NOTE(review): this used to map SpconvAddReLUNd to snniq.SparseConvReLU.
    # SPCONV_STATIC_LOWER_FUSED_MODULE_MAP pairs SpconvAddReLUNd with
    # snniq.SparseConvAddReLU, so the same target is used here; confirm the
    # add-relu quantized module is the intended one.
    cfg.set_observed_to_quantized_mapping(snni.SpconvAddReLUNd,
                                          snniq.SparseConvAddReLU)
    return cfg
spconv/pytorch/quantization/core.py
0 → 100644
View file @
e387ee74
from
typing
import
Union
,
List
,
Dict
import
torch
from
spconv.pytorch.core
import
SparseConvTensor
def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor],
                                   List[Union[SparseConvTensor,
                                              torch.Tensor]]], scale,
                        zero_point, dtype):
    """Per-tensor quantize tensors or the features of sparse conv tensors.

    Args:
        ten: a ``torch.Tensor``, a ``SparseConvTensor``, or a list/tuple of
            them.
        scale: quantization scale; indexed per element when ``ten`` is a
            list/tuple.
        zero_point: quantization zero point; indexed per element when
            ``ten`` is a list/tuple.
        dtype: quantized dtype passed to ``torch.quantize_per_tensor``.

    Returns:
        For a single input, the quantized tensor, or the result of
        ``replace_feature`` with quantized features for a
        ``SparseConvTensor``. For a list/tuple input, a list with one such
        result per element.
    """

    def _quantize_one(item, item_scale, item_zero_point):
        # Sparse tensors keep their structure; only the features are quantized.
        if isinstance(item, SparseConvTensor):
            quantized = torch.quantize_per_tensor(item.features, item_scale,
                                                  item_zero_point, dtype)
            return item.replace_feature(quantized)
        return torch.quantize_per_tensor(item, item_scale, item_zero_point,
                                         dtype)

    if not isinstance(ten, (list, tuple)):
        return _quantize_one(ten, scale, zero_point)
    return [
        _quantize_one(item, scale[idx], zero_point[idx])
        for idx, item in enumerate(ten)
    ]
\ No newline at end of file
spconv/pytorch/quantization/fake_q.py
View file @
e387ee74
...
...
@@ -11,7 +11,7 @@ from torch.ao.quantization.observer import (HistogramObserver,
from
torch.ao.quantization.qconfig
import
QConfig
,
QConfigAny
,
default_reuse_input_qconfig
from
torch.ao.quantization.qconfig_mapping
import
QConfigMapping
,
_FIXED_QPARAMS_OP_TO_OBSERVER
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
,
Union
,
List
from
torch.ao.quantization
import
get_default_qconfig
from
torch.ao.quantization
import
get_default_qconfig
,
get_default_qat_qconfig
from
spconv.pytorch.core
import
SparseConvTensor
__all__
=
[
"get_default_spconv_trt_ptq_qconfig"
,
"get_default_spconv_trt_qat_qconfig"
]
...
...
@@ -80,13 +80,14 @@ def get_default_spconv_trt_ptq_qconfig(backend, version):
def get_default_spconv_trt_qat_qconfig(backend, version):
    """Return ``default_symmetric_spconv_qat_qconfig``.

    ``backend`` and ``version`` are accepted but not used; the same
    module-level qconfig is returned whatever their values.
    """
    qat_qconfig = default_symmetric_spconv_qat_qconfig
    return qat_qconfig
def
get_default_spconv_qconfig_mapping
(
is_qat
:
bool
,
backend
:
str
=
"
x86
"
,
version
:
int
=
0
)
->
QConfigMapping
:
def
get_default_spconv_qconfig_mapping
(
is_qat
:
bool
,
backend
:
str
=
"
fbgemm
"
,
version
:
int
=
0
)
->
QConfigMapping
:
"""
From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend.
"""
# get_default_qconfig(backend, version)
if
is_qat
:
# qconfig = get_default_qat_qconfig(backend, version)
qconfig
=
get_default_spconv_trt_qat_qconfig
(
backend
,
version
)
else
:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
...
...
@@ -144,3 +145,4 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", versi
.
set_object_type
(
torch
.
nn
.
functional
.
tanh
,
qconfig
)
return
qconfig_mapping
spconv/pytorch/quantization/fuse_mapping.py
View file @
e387ee74
from
functools
import
partial
from
typing
import
Union
,
Callable
,
Tuple
,
Dict
,
Optional
,
Type
,
Any
import
torch.nn
as
nn
import
spconv.pytorch
as
spconv
...
...
@@ -5,7 +6,8 @@ from .utils import fuse_spconv_bn_eval
from
.
import
intrinsic
as
snni
from
.intrinsic.qat.modules
import
SparseConvBn
,
SparseConvBnReLU
,
SparseConvBnAddReLU
from
spconv.pytorch.conv
import
DEFAULT_SPARSE_CONV_TYPES
def
fuse_conv_bn
(
is_qat
,
conv
,
bn
):
def
fuse_conv_bn
(
is_qat
,
conv
,
bn
,
is_add_fuse
:
bool
=
False
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
...
...
@@ -20,11 +22,10 @@ def fuse_conv_bn(is_qat, conv, bn):
"""
assert
(
conv
.
training
==
bn
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
fuse_cls
=
snni
.
SpconvAddReLUNd
if
is_add_fuse
else
snni
.
SpconvBnNd
fused_module_class_map
=
{
k
:
snni
.
SpconvBnNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
k
:
fuse_cls
for
k
in
DEFAULT_SPARSE_CONV_TYPES
}
if
is_qat
:
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
...
...
@@ -37,7 +38,7 @@ def fuse_conv_bn(is_qat, conv, bn):
else
:
return
fuse_spconv_bn_eval
(
conv
,
bn
)
def
fuse_conv_bn_relu
(
is_qat
,
conv
,
bn
,
relu
):
def
fuse_conv_bn_relu
(
is_qat
,
conv
,
bn
,
relu
,
is_add_fuse
:
bool
=
False
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
...
...
@@ -54,8 +55,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)."
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
if
is_qat
:
fuse_cls
=
snni
.
SpconvBnAddReLUNd
if
is_add_fuse
else
snni
.
SpconvBnReLUNd
map_to_fused_module_train
=
{
k
:
snni
.
SpconvBnReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
k
:
fuse_cls
for
k
in
DEFAULT_SPARSE_CONV_TYPES
}
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
...
...
@@ -66,8 +68,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else
:
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
else
:
fuse_cls
=
snni
.
SpconvAddReLUNd
if
is_add_fuse
else
snni
.
SpconvReLUNd
map_to_fused_module_eval
=
{
k
:
snni
.
SpconvReLUNd
for
k
in
DEFAULT_SPARSE_CONV_TYPES
k
:
fuse_cls
for
k
in
DEFAULT_SPARSE_CONV_TYPES
}
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
if
fused_module
is
not
None
:
...
...
@@ -76,28 +79,21 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else
:
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
def fuse_conv_bn_add_relu(is_qat, relu, add_pattern):
    r"""Fuse a conv + bn + add + relu match by delegating to ``fuse_conv_bn_relu``.

    Args:
        is_qat: forwarded unchanged to ``fuse_conv_bn_relu``.
        relu: the ReLU module that closes the matched pattern.
        add_pattern: a 3-element sequence whose middle element is a
            ``(bn, conv)`` pair; the first and last elements are ignored.
            NOTE(review): presumably the (add, (bn, conv), other_input)
            match produced by the FX fusion pattern -- confirm against the
            backend config.

    Returns:
        Whatever ``fuse_conv_bn_relu`` returns when called with
        ``is_add_fuse`` set to ``True``.
    """
    _add_op, (bn, conv), _other_input = add_pattern
    return fuse_conv_bn_relu(is_qat, conv, bn, relu, True)
# Default map for swapping float module to qat modules
spconv/pytorch/quantization/intrinsic/__init__.py
View file @
e387ee74
...
...
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SpconvBnNd
,
SpconvBnReLUNd
,
SpconvBnAddReLUNd
,
SpconvReLUNd
from
.modules
import
SpconvBnNd
,
SpconvBnReLUNd
,
SpconvBnAddReLUNd
,
SpconvReLUNd
,
SpconvAddReLUNd
spconv/pytorch/quantization/intrinsic/modules.py
View file @
e387ee74
...
...
@@ -60,3 +60,27 @@ class SpconvBnAddReLUNd(_FusedSparseModule):
isinstance
(
relu
,
ReLU
),
'Incorrect types for input modules{}{}{}'
\
.
format
(
type
(
conv
),
type
(
bn
),
type
(
relu
))
super
().
__init__
(
conv
,
bn
,
relu
)
def forward(self, input, add_input):
    """Run conv -> bn -> (+ add_input) -> relu on sparse tensors.

    Args:
        input: sparse tensor passed to the conv submodule.
        add_input: sparse tensor whose ``features`` are added to the
            batch-normed conv features before the ReLU.

    Returns:
        The result of ``replace_feature`` on the conv output, carrying the
        activated features.
    """
    conv_mod, bn_mod, relu_mod = self[0], self[1], self[2]
    out = conv_mod(input)
    out = out.replace_feature(bn_mod(out.features))
    summed = out.features + add_input.features
    return out.replace_feature(relu_mod(summed))
class SpconvAddReLUNd(_FusedSparseModule):
    r"""Sequential container holding a sparse conv and a ReLU.

    ``forward`` adds the features of a second sparse tensor to the conv
    output before the ReLU is applied. There is no batch-norm submodule in
    this container.
    """

    def __init__(self, conv, relu):
        types_ok = isinstance(conv, SparseConvolution) and isinstance(relu, ReLU)
        assert types_ok, \
            'Incorrect types for input modules{}{}'.format(type(conv), type(relu))
        super().__init__(conv, relu)

    def forward(self, input, add_input):
        """Return relu(conv(input).features + add_input.features) as a sparse tensor."""
        conv_mod, relu_mod = self[0], self[1]
        out = conv_mod(input)
        summed = out.features + add_input.features
        return out.replace_feature(relu_mod(summed))
spconv/pytorch/quantization/intrinsic/qat/__init__.py
View file @
e387ee74
...
...
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.modules
import
SparseConvBn
,
SparseConvBnAddReLU
,
SparseConvBnReLU
,
SparseConv
,
SparseConvReLU
\ No newline at end of file
from
.modules
import
(
SparseConv
,
SparseConvAddReLU
,
SparseConvBn
,
SparseConvBnAddReLU
,
SparseConvBnReLU
,
SparseConvReLU
)
spconv/pytorch/quantization/intrinsic/qat/modules.py
View file @
e387ee74
...
...
@@ -17,7 +17,7 @@ import spconv.pytorch.quantization.intrinsic as snni
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
class
_SparseConv
(
SparseConvolution
,
nni
.
_FusedModule
):
class
_SparseConv
(
SparseConvolution
):
_FLOAT_MODULE
=
MOD
_FLOAT_CONV_MODULE
=
SparseConvolution
...
...
@@ -67,7 +67,7 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
self
.
weight_fake_quant
=
qconfig
.
weight
(
factory_kwargs
=
factory_kwargs
)
def forward(self, input):
    """Run the sparse convolution with a fake-quantized weight.

    The weight goes through ``self.weight_fake_quant`` and is handed to
    ``self._conv_forward`` together with ``self.bias``. The first argument
    is ``self.training``, as in the other ``_conv_forward`` calls in this
    module. Previously a leading ``return`` hard-coded ``False`` there and
    left the ``self.training`` call unreachable.

    Args:
        input: the input passed through to ``_conv_forward``.

    Returns:
        Whatever ``_conv_forward`` returns.
    """
    fake_quantized_weight = self.weight_fake_quant(self.weight)
    return self._conv_forward(self.training, input, fake_quantized_weight,
                              self.bias)
@
staticmethod
def
from_float
(
cls
,
mod
):
...
...
@@ -77,11 +77,12 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert
type
(
mod
)
==
cls
.
_FLOAT_MODULE
,
(
assert
issubclass
(
type
(
mod
)
,
cls
.
_FLOAT_MODULE
)
,
(
"qat."
+
cls
.
__name__
+
".from_float only works for "
+
cls
.
_FLOAT_MODULE
.
__name__
# type: ignore[attr-defined]
+
f
" not
{
type
(
mod
).
__qualname__
}
"
)
assert
hasattr
(
mod
,
'qconfig'
),
'Input float module must have qconfig defined'
assert
mod
.
qconfig
,
'Input float module must have a valid qconfig'
...
...
@@ -197,6 +198,33 @@ class SparseConvReLU(SparseConv, nni._FusedModule):
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvReLU
,
cls
).
from_float
(
mod
)
class SparseConvAddReLU(SparseConv, nni._FusedModule):
    r"""QAT module fusing a sparse conv, a residual add and a ReLU.

    The conv weight is passed through ``weight_fake_quant`` on every
    forward call, ``add_input`` is forwarded to ``_conv_forward``, and
    ``F.relu`` is applied to the resulting features.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = snni.SpconvAddReLUNd
    _FLOAT_CONV_MODULE = SparseConvolution
    # No batch-norm stage in this fused module.
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def forward(self, input, add_input):
        """Return the conv(+add) result with ReLU applied to its features."""
        fake_quantized_weight = self.weight_fake_quant(self.weight)
        out = self._conv_forward(self.training, input, fake_quantized_weight,
                                 self.bias, add_input=add_input)
        activated = F.relu(out.features)
        return out.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Delegate to the parent ``from_float`` with ``mod``."""
        return super().from_float(mod)
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
_version
=
2
...
...
@@ -323,9 +351,9 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
conv_spt
=
self
.
_conv_forward
(
self
.
training
,
input
,
scaled_weight
,
zero_bias
)
conv
=
conv_spt
.
features
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
conv_orig
=
conv
/
scale_factor
#
.reshape(bias_shape)
if
self
.
bias
is
not
None
:
conv_orig
=
conv_orig
+
self
.
bias
.
reshape
(
bias_shape
)
conv_orig
=
conv_orig
+
self
.
bias
#
.reshape(bias_shape)
conv
=
self
.
bn
(
conv_orig
)
if
add_input
is
not
None
:
conv
=
conv
+
add_input
.
features
...
...
@@ -377,7 +405,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
conv_out
=
torch
.
Tensor
()
if
self
.
bn
.
training
:
# needed to compute batch mean/std
conv_spt
=
self
.
_conv_forward
(
input
,
self
.
weight
,
zero_bias
)
conv_spt
=
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight
,
zero_bias
)
conv_out
=
conv_spt
.
features
# update bn statistics
with
torch
.
no_grad
():
...
...
@@ -393,7 +421,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
self
.
weight
*
scale_factor
.
reshape
(
weight_shape
)
)
# fused conv without bias for inference: (r * W / running_std) * X
conv_bn_spt
=
self
.
_conv_forward
(
input
,
scaled_weight
,
zero_bias
)
conv_bn_spt
=
self
.
_conv_forward
(
self
.
training
,
input
,
scaled_weight
,
zero_bias
)
conv_bn
=
conv_bn_spt
.
features
if
self
.
bn
.
training
:
avg_dims
=
[
0
]
+
list
(
range
(
2
,
len
(
self
.
weight
.
shape
)))
...
...
@@ -669,12 +697,12 @@ class SparseConvBnAddReLU(_SparseConvBn):
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE
=
snni
.
SpconvBnReLUNd
# type: ignore[assignment]
_FLOAT_MODULE
=
snni
.
SpconvBn
Add
ReLUNd
# type: ignore[assignment]
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
nn
.
BatchNorm1d
_FLOAT_RELU_MODULE
=
nn
.
ReLU
# type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE
=
snni
.
SpconvReLUNd
_FUSED_FLOAT_MODULE
=
snni
.
Spconv
Add
ReLUNd
def
forward
(
self
,
input
,
add_input
):
x
=
_SparseConvBn
.
_forward
(
self
,
input
,
add_input
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment