Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
e387ee74
Commit
e387ee74
authored
Jan 04, 2023
by
yan.yan
Browse files
sync quantization code
parent
b1c57a31
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
306 additions
and
18 deletions
+306
-18
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
+1
-1
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
+55
-8
spconv/pytorch/quantization/quantized/conv.py
spconv/pytorch/quantization/quantized/conv.py
+8
-6
spconv/pytorch/quantization/quantized/reference.py
spconv/pytorch/quantization/quantized/reference.py
+2
-2
spconv/pytorch/quantization/utils.py
spconv/pytorch/quantization/utils.py
+1
-0
test/debug/dev.py
test/debug/dev.py
+146
-0
test/debug/dev2.py
test/debug/dev2.py
+92
-0
test/test_all_algo.py
test/test_all_algo.py
+1
-1
No files found.
spconv/pytorch/quantization/intrinsic/quantized/__init__.py
View file @
e387ee74
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
View file @
e387ee74
...
...
@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Optional
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
,
SpconvAddReLUNd
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic
as
snni
import
torch
__all__
=
[
"SparseConvReLU"
]
__all__
=
[
"SparseConvReLU"
,
"SparseConvAddReLU"
]
class
SparseConvReLU
(
nnq
.
SparseConv
):
r
"""
...
...
@@ -36,16 +39,18 @@ class SparseConvReLU(nnq.SparseConv):
def
forward
(
self
,
input
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
.
to
(
torch
.
float32
)
out_scale
=
self
.
scale
channel_scale
=
out_scale
/
(
inp_scale
*
w_scales
)
bias
=
self
.
bias
()
*
out_scale
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
,
channel_scale
=
(
inp_scale
*
w_scales
)
/
out_scale
scaled_bias
=
self
.
bias
()
/
out_scale
# print(bias.dtype, input.features.dtype, channel_scale.dtype, w_scales.dtype)
res
=
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
scaled_bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
,
act_type
=
tv
.
gemm
.
Activation
.
ReLU
)
return
res
def
_get_name
(
self
):
return
'QuantizedConvReLU
1d
'
return
'Quantized
Sparse
ConvReLU'
@
classmethod
def
from_float
(
cls
,
mod
):
...
...
@@ -60,3 +65,45 @@ class SparseConvReLU(nnq.SparseConv):
assert
type
(
ref_qconv
)
!=
snni
.
SpconvBnReLUNd
,
\
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return
super
().
from_reference
(
ref_qconv
[
0
],
output_scale
,
output_zero_point
)
class
SparseConvAddReLU
(
nnq
.
SparseConv
):
r
"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE
=
SpconvAddReLUNd
# type: ignore[assignment]
def
forward
(
self
,
input
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
self
.
scale
channel_scale
=
(
inp_scale
*
w_scales
)
/
out_scale
scaled_bias
=
self
.
bias
()
/
out_scale
# print(bias.dtype, input.features.dtype, channel_scale.dtype, w_scales.dtype)
res
=
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
scaled_bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
,
act_type
=
tv
.
gemm
.
Activation
.
ReLU
,
add_input
=
add_input
)
return
res
def
_get_name
(
self
):
return
'QuantizedSparseConvReLU'
@
classmethod
def
from_float
(
cls
,
mod
):
if
type
(
mod
)
==
snniqat
.
SparseConvBnAddReLU
:
mod
.
weight
,
mod
.
bias
=
fuse_spconv_bn_weights
(
mod
.
weight
,
mod
.
bias
,
mod
.
bn
.
running_mean
,
mod
.
bn
.
running_var
,
mod
.
bn
.
eps
,
mod
.
bn
.
weight
,
mod
.
bn
.
bias
)
return
super
(
SparseConvAddReLU
,
cls
).
from_float
(
mod
)
@
classmethod
def
from_reference
(
cls
,
ref_qconv
,
output_scale
,
output_zero_point
):
assert
type
(
ref_qconv
)
!=
snni
.
SpconvBnReLUNd
,
\
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return
super
().
from_reference
(
ref_qconv
[
0
],
output_scale
,
output_zero_point
)
spconv/pytorch/quantization/quantized/conv.py
View file @
e387ee74
...
...
@@ -323,8 +323,11 @@ class SparseConv(_SparseConv):
return
'QuantizedSparseConvolution'
def
set_weight_bias
(
self
,
w
:
torch
.
Tensor
,
b
:
Optional
[
torch
.
Tensor
])
->
None
:
assert
b
is
not
None
self
.
_weight
=
w
if
b
is
None
:
# currently bias tensor must exists.
self
.
_bias
=
torch
.
zeros
((
w
.
shape
[
0
],),
dtype
=
torch
.
float32
,
device
=
w
.
device
)
else
:
self
.
_bias
=
b
def
weight
(
self
):
...
...
@@ -336,12 +339,11 @@ class SparseConv(_SparseConv):
def
forward
(
self
,
input
:
SparseConvTensor
):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
print
(
"?"
)
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
.
to
(
torch
.
float32
)
out_scale
=
self
.
scale
channel_scale
=
out_scale
/
(
inp_scale
*
w_scales
)
bias
=
self
.
bias
()
*
out_scale
channel_scale
=
(
inp_scale
*
w_scales
)
/
out_scale
bias
=
self
.
bias
()
/
out_scale
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
)
return
ops
.
quantized
.
conv1d
(
input
,
self
.
_packed_params
,
self
.
scale
,
self
.
zero_point
)
...
...
spconv/pytorch/quantization/quantized/reference.py
View file @
e387ee74
...
...
@@ -132,7 +132,7 @@ class SpConv(_SpConvNd, sconvmod.SparseConvolution):
device
=
device
)
self
.
_init_weight_qparams
(
weight_qparams
,
device
)
def
forward
(
self
,
x
:
SparseConvTensor
)
->
SparseConvTensor
:
def
forward
(
self
,
x
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
)
->
SparseConvTensor
:
"""
we have:
w(float) -- quant - dequant
\
...
...
@@ -144,7 +144,7 @@ class SpConv(_SpConvNd, sconvmod.SparseConvolution):
and the backend should be able to fuse the ops with `*` into a quantized SparseConvolution
"""
weight_quant_dequant
=
self
.
get_weight
()
result
=
self
.
_conv_forward
(
self
.
training
,
x
,
weight_quant_dequant
,
self
.
bias
)
result
=
self
.
_conv_forward
(
self
.
training
,
x
,
weight_quant_dequant
,
self
.
bias
,
add_input
=
add_input
)
return
result
def
_get_name
(
self
):
...
...
spconv/pytorch/quantization/utils.py
View file @
e387ee74
...
...
@@ -50,3 +50,4 @@ def fuse_spconv_act_eval(conv, act):
else
:
raise
NotImplementedError
return
fused_conv
test/debug/dev.py
0 → 100644
View file @
e387ee74
from
collections
import
OrderedDict
import
contextlib
import
operator
import
torch
import
torch.nn.functional
as
F
import
torch.nn
as
nn
from
torch.ao.quantization.fx.match_utils
import
(
MatchAllNode
,
)
from
torch.ao.quantization.quantize_fx
import
(
fuse_fx
,
)
from
torch.ao.quantization.backend_config
import
(
get_qnnpack_backend_config
,
BackendConfig
,
BackendPatternConfig
,
DTypeConfig
,
ObservationType
,
get_fbgemm_backend_config
)
from
torch.ao.quantization
import
get_default_qconfig_mapping
import
torch.ao.quantization.quantize_fx
as
qfx
class
M
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
self
.
bn
=
torch
.
nn
.
BatchNorm2d
(
3
)
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
maxpool
=
torch
.
nn
.
MaxPool2d
(
3
)
self
.
iden
=
nn
.
Identity
()
def
forward
(
self
,
x
):
y
=
x
y
=
self
.
iden
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
x
=
torch
.
add
(
x
,
y
)
x
=
self
.
relu
(
x
)
return
x
class
M2
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
self
.
bn1
=
torch
.
nn
.
BatchNorm2d
(
3
)
self
.
bn2
=
torch
.
nn
.
BatchNorm2d
(
3
)
self
.
relu1
=
torch
.
nn
.
ReLU
()
self
.
relu2
=
torch
.
nn
.
ReLU
()
self
.
iden
=
nn
.
Identity
()
def
forward
(
self
,
x
):
y
=
x
y
=
self
.
iden
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu1
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
bn2
(
x
)
x
=
torch
.
add
(
x
,
y
)
x
=
self
.
relu2
(
x
)
return
x
class
M3
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
conv1
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
conv2
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
bn1
=
torch
.
nn
.
BatchNorm2d
(
3
)
bn2
=
torch
.
nn
.
BatchNorm2d
(
3
)
relu1
=
torch
.
nn
.
ReLU
()
self
.
relu2
=
torch
.
nn
.
ReLU
()
self
.
conv1_bn_relu
=
nn
.
Sequential
(
OrderedDict
(
conv
=
conv1
,
bn
=
bn1
,
relu
=
nn
.
ReLU
(
inplace
=
True
)))
self
.
conv2_bn
=
nn
.
Sequential
(
OrderedDict
(
conv
=
conv2
,
bn
=
bn2
))
self
.
iden
=
nn
.
Identity
()
self
.
conv3
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
def
forward
(
self
,
x
):
y
=
x
y
=
self
.
conv3
(
x
)
x
=
self
.
conv1_bn_relu
(
x
)
x
=
self
.
conv2_bn
(
x
)
x
=
self
.
relu2
(
torch
.
add
(
x
,
y
))
return
x
m
=
M
().
eval
()
def
fuse_conv_bn_relu
(
is_qat
,
relu
,
add_pattern
):
_
,
bn_pattern
,
_
=
add_pattern
bn
,
conv
=
bn_pattern
return
conv
def
conv_bn_res_relu_root_node_getter
(
pattern
):
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
_
=
add_pattern
bn
,
conv
=
bn_pattern
return
conv
def
conv_bn_res_relu_extra_inputs_getter
(
pattern
):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
extra_input
=
add_pattern
bn
,
conv
=
bn_pattern
return
[
extra_input
]
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
# .set_fuser_method(fuse_conv_bn_relu) \
# ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
# ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
fbgemm_weighted_op_int8_dtype_config
=
DTypeConfig
(
input_dtype
=
torch
.
quint8
,
output_dtype
=
torch
.
quint8
,
weight_dtype
=
torch
.
qint8
,
bias_dtype
=
torch
.
float
,
)
conv_bn_res_relu_config
=
BackendPatternConfig
()
\
.
_set_pattern_complex_format
((
nn
.
ReLU
,
(
torch
.
add
,
(
nn
.
BatchNorm2d
,
nn
.
Conv2d
),
MatchAllNode
)))
\
.
set_fuser_method
(
fuse_conv_bn_relu
)
\
.
_set_root_node_getter
(
conv_bn_res_relu_root_node_getter
)
\
.
_set_extra_inputs_getter
(
conv_bn_res_relu_extra_inputs_getter
)
\
.
set_dtype_configs
(
fbgemm_weighted_op_int8_dtype_config
)
backend_config
=
get_fbgemm_backend_config
().
set_backend_pattern_config
(
conv_bn_res_relu_config
)
# m = fuse_fx(m, backend_config=backend_config)
qmapping
=
get_default_qconfig_mapping
()
prepared_model
=
qfx
.
prepare_fx
(
m
,
qmapping
,
(),
backend_config
=
backend_config
)
converted_model
=
qfx
.
convert_fx
(
prepared_model
,
qconfig_mapping
=
qmapping
,
backend_config
=
backend_config
)
converted_model
.
print_readable
()
\ No newline at end of file
test/debug/dev2.py
0 → 100644
View file @
e387ee74
from
collections
import
OrderedDict
import
contextlib
import
operator
import
torch
import
torch.nn.functional
as
F
import
torch.nn
as
nn
from
torch.ao.quantization.fx.match_utils
import
(
MatchAllNode
,
)
from
torch.ao.quantization.quantize_fx
import
(
fuse_fx
,
)
from
torch.ao.quantization.backend_config
import
(
get_qnnpack_backend_config
,
BackendConfig
,
BackendPatternConfig
,
DTypeConfig
,
ObservationType
,
get_fbgemm_backend_config
)
from
torch.ao.quantization
import
get_default_qconfig_mapping
import
torch.ao.quantization.quantize_fx
as
qfx
class
M
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
torch
.
nn
.
Conv2d
(
3
,
3
,
3
)
self
.
bn
=
torch
.
nn
.
BatchNorm2d
(
3
)
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
maxpool
=
torch
.
nn
.
MaxPool2d
(
3
)
self
.
iden
=
nn
.
Conv2d
(
3
,
3
,
3
)
# self.iden2 = nn.Conv2d(3, 3, 3)
def
forward
(
self
,
x
):
y
=
x
y
=
self
.
iden
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
x
=
torch
.
add
(
x
,
y
)
x
=
self
.
relu
(
x
)
return
x
m
=
M
().
eval
()
def
fuse_conv_bn_relu
(
is_qat
,
relu
,
add_pattern
):
_
,
bn_pattern
,
_
=
add_pattern
bn
,
conv
=
bn_pattern
return
conv
def
conv_bn_res_relu_root_node_getter
(
pattern
):
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
_
=
add_pattern
bn
,
conv
=
bn_pattern
return
conv
def
conv_bn_res_relu_extra_inputs_getter
(
pattern
):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu
,
add_pattern
=
pattern
_
,
bn_pattern
,
extra_input
=
add_pattern
bn
,
conv
=
bn_pattern
return
[
extra_input
]
# for pytorch <= 1.13
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
# .set_fuser_method(fuse_conv_bn_relu) \
# ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
# ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
fbgemm_weighted_op_int8_dtype_config
=
DTypeConfig
(
input_dtype
=
torch
.
quint8
,
output_dtype
=
torch
.
quint8
,
weight_dtype
=
torch
.
qint8
,
bias_dtype
=
torch
.
float
,
)
# for pytorch master
conv_bn_res_relu_config
=
BackendPatternConfig
()
\
.
_set_pattern_complex_format
((
nn
.
ReLU
,
(
torch
.
add
,
(
nn
.
BatchNorm2d
,
nn
.
Conv2d
),
MatchAllNode
)))
\
.
set_fuser_method
(
fuse_conv_bn_relu
)
\
.
_set_root_node_getter
(
conv_bn_res_relu_root_node_getter
)
\
.
_set_extra_inputs_getter
(
conv_bn_res_relu_extra_inputs_getter
)
\
.
set_dtype_configs
(
fbgemm_weighted_op_int8_dtype_config
)
backend_config
=
get_fbgemm_backend_config
()
# .set_backend_pattern_config(conv_bn_res_relu_config)
# m = fuse_fx(m, backend_config=backend_config)
qmapping
=
get_default_qconfig_mapping
()
prepared_model
=
qfx
.
prepare_fx
(
m
,
qmapping
,
(),
backend_config
=
backend_config
)
prepared_model
.
print_readable
()
converted_model
=
qfx
.
convert_fx
(
prepared_model
,
qconfig_mapping
=
qmapping
,
backend_config
=
backend_config
)
converted_model
.
print_readable
()
\ No newline at end of file
test/test_all_algo.py
View file @
e387ee74
...
...
@@ -88,7 +88,7 @@ class SparseConvTester:
op
=
expand_nd
(
ndim
,
0
)
self
.
kv
:
int
=
np
.
prod
(
self
.
ksize
)
self
.
num_split
=
1
if
algo
==
ConvAlgo
.
MaskImplicitGemm
else
2
self
.
output_scale
:
float
=
1.0
self
.
output_scale
:
float
=
3.4
self
.
check_int8_infer
=
check_int8_infer
if
check_int8_infer
:
assert
check_bias
and
self
.
dtype
==
np
.
int8
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment