Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
aa26c99e
Commit
aa26c99e
authored
Dec 29, 2022
by
yan.yan
Browse files
working on quantization
parent
ee8c9465
Changes
29
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
965 additions
and
133 deletions
+965
-133
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+0
-0
spconv/pytorch/quantization/conv_fused.py
spconv/pytorch/quantization/conv_fused.py
+475
-0
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+23
-0
spconv/pytorch/quantization/fuse_mapping.py
spconv/pytorch/quantization/fuse_mapping.py
+130
-0
spconv/pytorch/quantization/intrinsic.py
spconv/pytorch/quantization/intrinsic.py
+35
-0
spconv/pytorch/quantization/modules.py
spconv/pytorch/quantization/modules.py
+21
-0
spconv/pytorch/quantization/utils.py
spconv/pytorch/quantization/utils.py
+52
-0
spconv/test_utils.py
spconv/test_utils.py
+5
-5
test/test_all_algo.py
test/test_all_algo.py
+224
-128
No files found.
spconv/pytorch/quantization/__init__.py
0 → 100644
View file @
aa26c99e
spconv/pytorch/quantization/conv_fused.py
0 → 100644
View file @
aa26c99e
# torch.ao.nn.intrinsic.qat.modules.conv_fused
import
math
import
torch
import
torch.nn
as
nn
import
torch.ao.nn.intrinsic
as
nni
import
torch.ao.nn.qat
as
nnqat
import
torch.nn.functional
as
F
from
torch.nn
import
init
from
torch.nn.utils
import
fuse_conv_bn_weights
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
torch.nn.parameter
import
Parameter
from
typing
import
TypeVar
from
spconv.pytorch.conv
import
SparseConvolution
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
spconv.core
import
ConvAlgo
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.quantization.intrinsic
as
snni
MOD
=
TypeVar
(
'MOD'
,
bound
=
SparseConvolution
)
class
_SparseConvBn
(
SparseConvolution
,
nni
.
_FusedModule
):
_version
=
2
_FLOAT_MODULE
=
MOD
_FLOAT_CONV_MODULE
=
SparseConvolution
def
__init__
(
self
,
# SparseConvolution args
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
,
# BatchNormNd args
# num_features: out_channels
eps
=
1e-05
,
momentum
=
0.1
,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn
=
False
,
qconfig
=
None
):
SparseConvolution
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
False
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
name
=
name
)
assert
qconfig
,
'qconfig must be provided for QAT module'
self
.
qconfig
=
qconfig
self
.
freeze_bn
=
freeze_bn
if
self
.
training
else
True
self
.
bn
=
nn
.
BatchNorm1d
(
out_channels
,
eps
,
momentum
,
True
,
True
)
self
.
weight_fake_quant
=
self
.
qconfig
.
weight
()
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
out_channels
))
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
reset_bn_parameters
()
# this needs to be called after reset_bn_parameters,
# as they modify the same state
if
self
.
training
:
if
freeze_bn
:
self
.
freeze_bn_stats
()
else
:
self
.
update_bn_stats
()
else
:
self
.
freeze_bn_stats
()
self
.
_enable_slow_path_for_better_numerical_stability
=
False
def
reset_running_stats
(
self
):
self
.
bn
.
reset_running_stats
()
def
reset_bn_parameters
(
self
):
self
.
bn
.
reset_running_stats
()
init
.
uniform_
(
self
.
bn
.
weight
)
init
.
zeros_
(
self
.
bn
.
bias
)
# note: below is actully for conv, not BN
if
self
.
bias
is
not
None
:
fan_in
,
_
=
init
.
_calculate_fan_in_and_fan_out
(
self
.
weight
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
def
reset_parameters
(
self
):
super
(
_SparseConvBn
,
self
).
reset_parameters
()
def
update_bn_stats
(
self
):
self
.
freeze_bn
=
False
self
.
bn
.
training
=
True
return
self
def
freeze_bn_stats
(
self
):
self
.
freeze_bn
=
True
self
.
bn
.
training
=
False
return
self
def
_forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
assert
not
self
.
_enable_slow_path_for_better_numerical_stability
if
self
.
_enable_slow_path_for_better_numerical_stability
:
return
self
.
_forward_slow
(
input
)
return
self
.
_forward_approximate
(
input
,
add_input
)
def
_forward_approximate
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
"""Approximated method to fuse conv and bn. It requires only one forward pass.
conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
"""
assert
self
.
bn
.
running_var
is
not
None
running_std
=
torch
.
sqrt
(
self
.
bn
.
running_var
+
self
.
bn
.
eps
)
scale_factor
=
self
.
bn
.
weight
/
running_std
weight_shape
=
[
1
]
*
len
(
self
.
weight
.
shape
)
weight_shape
[
0
]
=
-
1
bias_shape
=
[
1
]
*
len
(
self
.
weight
.
shape
)
bias_shape
[
1
]
=
-
1
scaled_weight
=
self
.
weight_fake_quant
(
self
.
weight
*
scale_factor
.
reshape
(
weight_shape
))
# using zero bias here since the bias for original conv
# will be added later
if
self
.
bias
is
not
None
:
zero_bias
=
torch
.
zeros_like
(
self
.
bias
,
dtype
=
input
.
features
.
dtype
)
else
:
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
scaled_weight
.
device
,
dtype
=
input
.
features
.
dtype
)
conv_spt
=
self
.
_conv_forward
(
input
,
scaled_weight
,
zero_bias
)
conv
=
conv_spt
.
features
conv_orig
=
conv
/
scale_factor
.
reshape
(
bias_shape
)
if
self
.
bias
is
not
None
:
conv_orig
=
conv_orig
+
self
.
bias
.
reshape
(
bias_shape
)
conv
=
self
.
bn
(
conv_orig
)
if
add_input
is
not
None
:
conv
=
conv
+
add_input
.
features
conv_spt
=
conv_spt
.
replace_feature
(
conv
)
return
conv_spt
def
_forward_slow
(
self
,
input
:
SparseConvTensor
):
"""
TODO not implemented for now
A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
It requires two forward passes but handles the case bn.weight == 0
Conv: Y = WX + B_c
Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
Batch statistics:
mean_Y = Y.mean()
= Y0.mean() + B_c
var_Y = (Y - mean_Y)^2.mean()
= (Y0 - Y0.mean())^2.mean()
BN (r: bn.weight, beta: bn.bias):
Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
= r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
= (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
Fused Conv BN inference (running_std = sqrt(running_var + eps)):
Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
QAT with fused conv bn:
Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
= conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
"""
assert
self
.
bn
.
running_var
is
not
None
assert
self
.
bn
.
running_mean
is
not
None
# using zero bias here since the bias for original conv
# will be added later
zero_bias
=
torch
.
zeros
(
self
.
out_channels
,
device
=
self
.
weight
.
device
,
dtype
=
input
.
features
.
dtype
)
weight_shape
=
[
1
]
*
len
(
self
.
weight
.
shape
)
weight_shape
[
0
]
=
-
1
bias_shape
=
[
1
]
*
len
(
self
.
weight
.
shape
)
bias_shape
[
1
]
=
-
1
conv_out
=
torch
.
Tensor
()
if
self
.
bn
.
training
:
# needed to compute batch mean/std
conv_spt
=
self
.
_conv_forward
(
input
,
self
.
weight
,
zero_bias
)
conv_out
=
conv_spt
.
features
# update bn statistics
with
torch
.
no_grad
():
conv_out_bias
=
(
conv_out
if
self
.
bias
is
None
else
conv_out
+
self
.
bias
.
reshape
(
bias_shape
)
)
self
.
bn
(
conv_out_bias
)
# fused conv + bn without bias using bn running statistics
running_std
=
torch
.
sqrt
(
self
.
bn
.
running_var
+
self
.
bn
.
eps
)
scale_factor
=
self
.
bn
.
weight
/
running_std
scaled_weight
=
self
.
weight_fake_quant
(
self
.
weight
*
scale_factor
.
reshape
(
weight_shape
)
)
# fused conv without bias for inference: (r * W / running_std) * X
conv_bn_spt
=
self
.
_conv_forward
(
input
,
scaled_weight
,
zero_bias
)
conv_bn
=
conv_bn_spt
.
features
if
self
.
bn
.
training
:
avg_dims
=
[
0
]
+
list
(
range
(
2
,
len
(
self
.
weight
.
shape
)))
batch_mean
=
conv_out
.
mean
(
avg_dims
)
batch_var
=
torch
.
square
(
conv_out
-
batch_mean
.
reshape
(
bias_shape
)).
mean
(
avg_dims
)
batch_std
=
torch
.
sqrt
(
batch_var
+
self
.
bn
.
eps
)
# scale to use batch std in training mode
# conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
unscale_factor
=
running_std
/
batch_std
conv_bn
*=
unscale_factor
.
reshape
(
bias_shape
)
fused_mean
=
batch_mean
fused_std
=
batch_std
else
:
fused_mean
=
self
.
bn
.
running_mean
-
(
self
.
bias
if
self
.
bias
is
not
None
else
0
)
fused_std
=
running_std
# fused bias = beta - r * mean / std
fused_bias
=
self
.
bn
.
bias
-
self
.
bn
.
weight
*
fused_mean
/
fused_std
conv_bn
+=
fused_bias
.
reshape
(
bias_shape
)
# HACK to let conv bias particpiate in loss to avoid DDP error (parameters
# were not used in producing loss)
if
self
.
bias
is
not
None
:
conv_bn
+=
(
self
.
bias
-
self
.
bias
).
reshape
(
bias_shape
)
conv_bn_spt
=
conv_bn_spt
.
replace_feature
(
conv_bn
)
return
conv_bn_spt
return
conv_bn
def
extra_repr
(
self
):
# TODO(jerryzh): extend
return
super
(
_SparseConvBn
,
self
).
extra_repr
()
def
forward
(
self
,
input
):
return
self
.
_forward
(
input
)
def
train
(
self
,
mode
=
True
):
"""
Batchnorm's training behavior is using the self.training flag. Prevent
changing it if BN is frozen. This makes sure that calling `model.train()`
on a model with a frozen BN will behave properly.
"""
self
.
training
=
mode
if
not
self
.
freeze_bn
:
for
module
in
self
.
children
():
module
.
train
(
mode
)
return
self
# ===== Serialization version history =====
#
# Version 1/None
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- gamma : Tensor
# |--- beta : Tensor
# |--- running_mean : Tensor
# |--- running_var : Tensor
# |--- num_batches_tracked : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- bn : Module
# |--- weight : Tensor (moved from v1.self.gamma)
# |--- bias : Tensor (moved from v1.self.beta)
# |--- running_mean : Tensor (moved from v1.self.running_mean)
# |--- running_var : Tensor (moved from v1.self.running_var)
# |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
version
=
local_metadata
.
get
(
'version'
,
None
)
if
version
is
None
or
version
==
1
:
# BN related parameters and buffers were moved into the BN module for v2
v2_to_v1_names
=
{
'bn.weight'
:
'gamma'
,
'bn.bias'
:
'beta'
,
'bn.running_mean'
:
'running_mean'
,
'bn.running_var'
:
'running_var'
,
'bn.num_batches_tracked'
:
'num_batches_tracked'
,
}
for
v2_name
,
v1_name
in
v2_to_v1_names
.
items
():
if
prefix
+
v1_name
in
state_dict
:
state_dict
[
prefix
+
v2_name
]
=
state_dict
[
prefix
+
v1_name
]
state_dict
.
pop
(
prefix
+
v1_name
)
elif
prefix
+
v2_name
in
state_dict
:
# there was a brief period where forward compatibility
# for this module was broken (between
# https://github.com/pytorch/pytorch/pull/38478
# and https://github.com/pytorch/pytorch/pull/38820)
# and modules emitted the v2 state_dict format while
# specifying that version == 1. This patches the forward
# compatibility issue by allowing the v2 style entries to
# be used.
pass
elif
strict
:
missing_keys
.
append
(
prefix
+
v2_name
)
super
(
_SparseConvBn
,
self
).
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
)
@
classmethod
def
from_float
(
cls
,
mod
):
r
"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
# The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
# has no __name__ (code is fine though)
assert
type
(
mod
)
==
cls
.
_FLOAT_MODULE
,
'qat.'
+
cls
.
__name__
+
'.from_float only works for '
+
\
cls
.
_FLOAT_MODULE
.
__name__
# type: ignore[attr-defined]
assert
hasattr
(
mod
,
'qconfig'
),
'Input float module must have qconfig defined'
assert
mod
.
qconfig
,
'Input float module must have a valid qconfig'
qconfig
=
mod
.
qconfig
conv
:
SparseConvolution
=
mod
[
0
]
bn
:
nn
.
BatchNorm1d
=
mod
[
1
]
qat_convbn
=
cls
(
conv
.
ndim
,
conv
.
in_channels
,
conv
.
out_channels
,
conv
.
kernel_size
,
conv
.
stride
,
conv
.
padding
,
conv
.
dilation
,
conv
.
groups
,
conv
.
bias
is
not
None
,
subm
=
conv
.
subm
,
output_padding
=
conv
.
output_padding
,
transposed
=
conv
.
transposed
,
inverse
=
conv
.
inverse
,
indice_key
=
conv
.
indice_key
,
algo
=
conv
.
algo
,
fp32_accum
=
conv
.
fp32_accum
,
record_voxel_count
=
conv
.
record_voxel_count
,
act_type
=
conv
.
act_type
,
act_alpha
=
conv
.
act_alpha
,
act_beta
=
conv
.
act_beta
,
name
=
conv
.
name
,
eps
=
bn
.
eps
,
momentum
=
bn
.
momentum
,
freeze_bn
=
False
,
qconfig
=
qconfig
)
qat_convbn
.
weight
=
conv
.
weight
qat_convbn
.
bias
=
conv
.
bias
qat_convbn
.
bn
.
weight
=
bn
.
weight
qat_convbn
.
bn
.
bias
=
bn
.
bias
qat_convbn
.
bn
.
running_mean
=
bn
.
running_mean
qat_convbn
.
bn
.
running_var
=
bn
.
running_var
# mypy error: Cannot determine type of 'num_batches_tracked'
qat_convbn
.
bn
.
num_batches_tracked
=
bn
.
num_batches_tracked
# type: ignore[has-type]
return
qat_convbn
def
to_float
(
self
):
cls
=
type
(
self
)
conv
=
cls
.
_FLOAT_CONV_MODULE
(
# type: ignore[attr-defined]
self
.
ndim
,
self
.
in_channels
,
self
.
out_channels
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
,
self
.
bias
is
not
None
,
subm
=
self
.
subm
,
output_padding
=
self
.
output_padding
,
transposed
=
self
.
transposed
,
inverse
=
self
.
inverse
,
indice_key
=
self
.
indice_key
,
algo
=
self
.
algo
,
fp32_accum
=
self
.
fp32_accum
,
record_voxel_count
=
self
.
record_voxel_count
,
act_type
=
self
.
act_type
,
act_alpha
=
self
.
act_alpha
,
act_beta
=
self
.
act_beta
,
name
=
self
.
name
)
conv
.
weight
=
torch
.
nn
.
Parameter
(
self
.
weight
.
detach
())
if
self
.
bias
is
not
None
:
conv
.
bias
=
torch
.
nn
.
Parameter
(
self
.
bias
.
detach
())
if
cls
.
_FLOAT_BN_MODULE
:
# type: ignore[attr-defined]
# fuse bn into conv
conv
.
weight
,
conv
.
bias
=
fuse_conv_bn_weights
(
conv
.
weight
,
conv
.
bias
,
self
.
bn
.
running_mean
,
self
.
bn
.
running_var
,
self
.
bn
.
eps
,
self
.
bn
.
weight
,
self
.
bn
.
bias
)
if
cls
.
_FLOAT_RELU_MODULE
:
# type: ignore[attr-defined]
modules
=
[]
modules
.
append
(
conv
)
relu
=
cls
.
_FLOAT_RELU_MODULE
()
# type: ignore[attr-defined]
modules
.
append
(
relu
)
conv_relu
=
cls
.
_FUSED_FLOAT_MODULE
(
*
modules
)
# type: ignore[attr-defined]
conv_relu
.
train
(
self
.
training
)
return
conv_relu
else
:
conv
.
train
(
self
.
training
)
return
conv
class
SparseConvBn
(
_SparseConvBn
):
r
"""
A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d`.
Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
to default.
Attributes:
freeze_bn:
weight_fake_quant: fake quant module for weight
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE
=
snni
.
SpconvBnNd
# type: ignore[assignment]
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
nn
.
BatchNorm1d
_FLOAT_RELU_MODULE
=
None
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE
=
snni
.
SpconvReLUNd
class
SparseConvBnReLU
(
_SparseConvBn
):
r
"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE
=
snni
.
SpconvBnReLUNd
# type: ignore[assignment]
_FLOAT_CONV_MODULE
=
SparseConvolution
_FLOAT_BN_MODULE
=
nn
.
BatchNorm1d
_FLOAT_RELU_MODULE
=
nn
.
ReLU
# type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE
=
snni
.
SpconvReLUNd
def
forward
(
self
,
input
):
x
=
_SparseConvBn
.
_forward
(
self
,
input
)
return
x
.
replace_feature
(
F
.
relu
(
x
.
features
))
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
SparseConvBnReLU
,
cls
).
from_float
(
mod
)
spconv/pytorch/quantization/fake_q.py
0 → 100644
View file @
aa26c99e
from
torch.ao.quantization.fake_quantize
import
FusedMovingAvgObsFakeQuantize
,
fused_wt_fake_quant_range_neg_127_to_127
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
from
torch.ao.quantization.qconfig
import
QConfig
from
torch.ao.quantization.observer
import
MovingAverageMinMaxObserver
class
SparseFusedMovingAvgObsFakeQuantize
(
FusedMovingAvgObsFakeQuantize
):
def
forward
(
self
,
input
:
SparseConvTensor
):
# add lines to support spconv
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
default_symmetric_spconv_qat_qconfig
=
QConfig
(
activation
=
SparseFusedMovingAvgObsFakeQuantize
.
with_args
(
observer
=
MovingAverageMinMaxObserver
,
quant_min
=-
128
,
quant_max
=
127
,
dtype
=
torch
.
qint8
,
reduce_range
=
False
,
eps
=
2
**
-
12
),
weight
=
fused_wt_fake_quant_range_neg_127_to_127
)
spconv/pytorch/quantization/fuse_mapping.py
0 → 100644
View file @
aa26c99e
from
typing
import
Union
,
Callable
,
Tuple
,
Dict
,
Optional
,
Type
,
Any
import
torch.nn
as
nn
import
spconv.pytorch
as
spconv
from
.utils
import
fuse_spconv_bn_eval
from
.
import
intrinsic
as
snni
from
.conv_fused
import
SparseConvBn
,
SparseConvBnReLU
def
fuse_conv_bn
(
conv
,
bn
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert
(
conv
.
training
==
bn
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnNd
,
}
if
conv
.
training
:
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
assert
bn
.
affine
,
'Only support fusing BatchNorm2d with affine set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm2d with tracking_running_stats set to True'
fused_module_class
=
fused_module_class_map
.
get
((
type
(
conv
)),
None
)
if
fused_module_class
is
not
None
:
return
fused_module_class
(
conv
,
bn
)
else
:
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
)))
else
:
return
fuse_spconv_bn_eval
(
conv
,
bn
)
def
fuse_conv_bn_relu
(
conv
,
bn
,
relu
):
r
"""Given the conv and bn modules, fuses them and returns the fused module
Args:
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert
(
conv
.
training
==
bn
.
training
==
relu
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
fused_module
:
Optional
[
Type
[
spconv
.
SparseSequential
]]
=
None
if
conv
.
training
:
map_to_fused_module_train
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvBnReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvBnReLUNd
,
}
assert
bn
.
num_features
==
conv
.
out_channels
,
'Output channel of Conv must match num_features of BatchNorm'
assert
bn
.
affine
,
'Only support fusing BatchNorm with affine set to True'
assert
bn
.
track_running_stats
,
'Only support fusing BatchNorm with tracking_running_stats set to True'
fused_module
=
map_to_fused_module_train
.
get
(
type
(
conv
),
None
)
if
fused_module
is
not
None
:
return
fused_module
(
conv
,
bn
,
relu
)
else
:
raise
NotImplementedError
(
"Cannot fuse train modules: {}"
.
format
((
conv
,
bn
,
relu
)))
else
:
map_to_fused_module_eval
=
{
spconv
.
SubMConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv1d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv2d
:
snni
.
SpconvReLUNd
,
spconv
.
SubMConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseConv3d
:
snni
.
SpconvReLUNd
,
spconv
.
SparseInverseConv3d
:
snni
.
SpconvReLUNd
,
}
fused_module
=
map_to_fused_module_eval
.
get
(
type
(
conv
),
None
)
if
fused_module
is
not
None
:
fused_conv
=
fuse_spconv_bn_eval
(
conv
,
bn
)
return
fused_module
(
fused_conv
,
relu
)
else
:
raise
NotImplementedError
(
"Cannot fuse eval modules: {}"
.
format
((
conv
,
bn
,
relu
)))
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
:
Dict
[
Tuple
,
Union
[
nn
.
Sequential
,
Callable
]]
=
{
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SubMConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseInverseConv1d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SubMConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseInverseConv2d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SubMConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
):
fuse_conv_bn
,
(
spconv
.
SparseInverseConv3d
,
nn
.
BatchNorm1d
,
nn
.
ReLU
):
fuse_conv_bn_relu
,
}
# Default map for swapping float module to qat modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS
:
Dict
[
Callable
,
Any
]
=
{
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni
.
SpconvBnNd
:
SparseConvBn
,
snni
.
SpconvBnReLUNd
:
SparseConvBnReLU
,
}
spconv/pytorch/quantization/intrinsic.py
0 → 100644
View file @
aa26c99e
import
torch
from
torch.nn
import
Conv1d
,
Conv2d
,
Conv3d
,
ReLU
,
Linear
,
BatchNorm1d
,
BatchNorm2d
,
BatchNorm3d
from
torch.nn.utils.parametrize
import
type_before_parametrizations
import
torch.ao.nn.intrinsic
as
nni
from
spconv.pytorch.conv
import
SparseConvolution
class
SpconvReLUNd
(
nni
.
_FusedModule
):
r
"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
relu
):
assert
isinstance
(
conv
,
SparseConvolution
)
and
isinstance
(
relu
,
ReLU
),
\
'Incorrect types for input modules{}{}'
.
format
(
type
(
conv
),
type
(
relu
))
super
().
__init__
(
conv
,
relu
)
class
SpconvBnNd
(
nni
.
_FusedModule
):
r
"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
):
assert
isinstance
(
conv
,
SparseConvolution
)
and
isinstance
(
bn
,
BatchNorm1d
),
\
'Incorrect types for input modules{}{}'
.
format
(
type
(
conv
),
type
(
bn
))
super
().
__init__
(
conv
,
bn
)
class
SpconvBnReLUNd
(
nni
.
_FusedModule
):
r
"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def
__init__
(
self
,
conv
,
bn
,
relu
):
assert
isinstance
(
conv
,
SparseConvolution
)
and
isinstance
(
bn
,
BatchNorm1d
)
and
\
isinstance
(
relu
,
ReLU
),
'Incorrect types for input modules{}{}{}'
\
.
format
(
type
(
conv
),
type
(
bn
),
type
(
relu
))
super
().
__init__
(
conv
,
bn
,
relu
)
spconv/pytorch/quantization/modules.py
0 → 100644
View file @
aa26c99e
from
spconv.pytorch.modules
import
SparseModule
from
spconv.pytorch.conv
import
SparseConvolution
from
spconv.pytorch.core
import
SparseConvTensor
import
torch
class
ConvBatchNormAddAct
(
torch
.
nn
.
Module
):
"""for simple int8 residual op fusion, we can use this module to handle add.
"""
def
__init__
(
self
,
conv
:
SparseConvolution
,
bn
:
torch
.
nn
.
BatchNorm1d
,
act
:
torch
.
nn
.
ReLU
)
->
None
:
super
().
__init__
()
self
.
conv
=
conv
self
.
bn
=
bn
self
.
act
=
act
def
forward
(
self
,
x
:
SparseConvTensor
,
x_add
:
SparseConvTensor
):
x
=
self
.
conv
(
x
)
x
=
x
.
replace_feature
(
self
.
bn
(
x
.
features
))
return
self
.
act
(
x
.
replace_feature
(
x
.
features
+
x_add
.
features
))
spconv/pytorch/quantization/utils.py
0 → 100644
View file @
aa26c99e
import
torch
import
copy
from
cumm
import
tensorview
as
tv
def
fuse_spconv_bn_weights
(
conv_w_OKI
,
conv_b
,
bn_rm
,
bn_rv
,
bn_eps
,
bn_w
,
bn_b
):
NDim
=
conv_w_OKI
.
ndim
-
2
permute
=
[
0
,
NDim
+
1
]
+
[
i
+
1
for
i
in
range
(
NDim
)]
conv_w_OIK
=
conv_w_OKI
.
permute
(
*
permute
)
# OIDHW
if
conv_b
is
None
:
conv_b
=
torch
.
zeros_like
(
bn_rm
)
if
bn_w
is
None
:
bn_w
=
torch
.
ones_like
(
bn_rm
)
if
bn_b
is
None
:
bn_b
=
torch
.
zeros_like
(
bn_rm
)
bn_var_rsqrt
=
torch
.
rsqrt
(
bn_rv
+
bn_eps
)
conv_w_OIK
=
conv_w_OIK
*
(
bn_w
*
bn_var_rsqrt
).
reshape
([
-
1
]
+
[
1
]
*
(
len
(
conv_w_OIK
.
shape
)
-
1
))
conv_b
=
(
conv_b
-
bn_rm
)
*
bn_var_rsqrt
*
bn_w
+
bn_b
permute
=
[
0
,]
+
[
i
+
2
for
i
in
range
(
NDim
)]
+
[
1
,]
conv_w_OKI
=
conv_w_OIK
.
permute
(
*
permute
).
contiguous
()
return
torch
.
nn
.
Parameter
(
conv_w_OKI
),
torch
.
nn
.
Parameter
(
conv_b
)
def
fuse_spconv_bn_eval
(
conv
,
bn
):
"""
Given a conv Module `A` and an batch_norm module `B`, returns a conv
module `C` such that C(x) == B(A(x)) in inference mode.
"""
assert
(
not
(
conv
.
training
or
bn
.
training
)),
"Fusion only for eval!"
fused_conv
=
copy
.
deepcopy
(
conv
)
fused_conv
.
weight
,
fused_conv
.
bias
=
\
fuse_spconv_bn_weights
(
fused_conv
.
weight
,
fused_conv
.
bias
,
bn
.
running_mean
,
bn
.
running_var
,
bn
.
eps
,
bn
.
weight
,
bn
.
bias
)
return
fused_conv
def
fuse_spconv_act_eval
(
conv
,
act
):
"""
Given a conv Module `A` and an batch_norm module `B`, returns a conv
module `C` such that C(x) == B(A(x)) in inference mode.
"""
assert
(
not
(
conv
.
training
)),
"Fusion only for eval!"
fused_conv
=
copy
.
deepcopy
(
conv
)
if
isinstance
(
act
,
torch
.
nn
.
ReLU
):
fused_conv
.
act_type
=
tv
.
gemm
.
Activation
.
ReLU
elif
isinstance
(
act
,
torch
.
nn
.
LeakyReLU
):
fused_conv
.
act_type
=
tv
.
gemm
.
Activation
.
LeakyReLU
fused_conv
.
act_alpha
=
act
.
negative_slope
else
:
raise
NotImplementedError
return
fused_conv
spconv/test_utils.py
View file @
aa26c99e
...
@@ -53,7 +53,7 @@ class TestCase(unittest.TestCase):
...
@@ -53,7 +53,7 @@ class TestCase(unittest.TestCase):
print
(
"not equal rhs = "
,
y
)
print
(
"not equal rhs = "
,
y
)
np
.
testing
.
assert_array_equal
(
a
,
b
)
np
.
testing
.
assert_array_equal
(
a
,
b
)
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
):
def
assertAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
:
str
=
""
):
"""Asserts that two numpy arrays, or dicts of same, have near values.
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
This does not support nested dicts.
Args:
Args:
...
@@ -68,22 +68,22 @@ class TestCase(unittest.TestCase):
...
@@ -68,22 +68,22 @@ class TestCase(unittest.TestCase):
"""
"""
is_a_dict
=
isinstance
(
a
,
dict
)
is_a_dict
=
isinstance
(
a
,
dict
)
if
is_a_dict
!=
isinstance
(
b
,
dict
):
if
is_a_dict
!=
isinstance
(
b
,
dict
):
raise
ValueError
(
"Can't compare dict to non-dict, %s vs %s."
%
raise
ValueError
(
f
"Can't compare dict to non-dict, %s vs %s.
{
msg
}
"
%
(
a
,
b
))
(
a
,
b
))
if
is_a_dict
:
if
is_a_dict
:
self
.
assertCountEqual
(
a
.
keys
(),
self
.
assertCountEqual
(
a
.
keys
(),
b
.
keys
(),
b
.
keys
(),
msg
=
"mismatched keys, expected %s, got %s"
%
msg
=
f
"mismatched keys, expected %s, got %s
.
{
msg
}
"
%
(
a
.
keys
(),
b
.
keys
()))
(
a
.
keys
(),
b
.
keys
()))
for
k
in
a
:
for
k
in
a
:
self
.
_assertArrayLikeAllClose
(
a
[
k
],
self
.
_assertArrayLikeAllClose
(
a
[
k
],
b
[
k
],
b
[
k
],
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
,
atol
=
atol
,
msg
=
"%s: expected %s, got %s."
%
msg
=
f
"%s: expected %s, got %s.
{
msg
}
"
%
(
k
,
a
,
b
))
(
k
,
a
,
b
))
else
:
else
:
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
,
msg
=
msg
)
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
a
=
self
.
_GetNdArray
(
a
)
a
=
self
.
_GetNdArray
(
a
)
...
...
test/test_all_algo.py
View file @
aa26c99e
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment