Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
b1c57a31
Commit
b1c57a31
authored
Jan 03, 2023
by
yan.yan
Browse files
still working on int8
parent
aa26c99e
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
641 additions
and
0 deletions
+641
-0
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
+62
-0
spconv/pytorch/quantization/qmapping.py
spconv/pytorch/quantization/qmapping.py
+50
-0
spconv/pytorch/quantization/quantized/__init__.py
spconv/pytorch/quantization/quantized/__init__.py
+15
-0
spconv/pytorch/quantization/quantized/conv.py
spconv/pytorch/quantization/quantized/conv.py
+359
-0
spconv/pytorch/quantization/quantized/reference.py
spconv/pytorch/quantization/quantized/reference.py
+155
-0
No files found.
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic
as
snni
__all__
=
[
"SparseConvReLU"
]
class
SparseConvReLU
(
nnq
.
SparseConv
):
r
"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE
=
SpconvReLUNd
# type: ignore[assignment]
def
forward
(
self
,
input
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
out_scale
=
self
.
scale
channel_scale
=
out_scale
/
(
inp_scale
*
w_scales
)
bias
=
self
.
bias
()
*
out_scale
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
,
act_type
=
tv
.
gemm
.
Activation
.
ReLU
)
def
_get_name
(
self
):
return
'QuantizedConvReLU1d'
@
classmethod
def
from_float
(
cls
,
mod
):
if
type
(
mod
)
==
snniqat
.
SparseConvBnReLU
:
mod
.
weight
,
mod
.
bias
=
fuse_spconv_bn_weights
(
mod
.
weight
,
mod
.
bias
,
mod
.
bn
.
running_mean
,
mod
.
bn
.
running_var
,
mod
.
bn
.
eps
,
mod
.
bn
.
weight
,
mod
.
bn
.
bias
)
return
super
(
SparseConvReLU
,
cls
).
from_float
(
mod
)
@
classmethod
def
from_reference
(
cls
,
ref_qconv
,
output_scale
,
output_zero_point
):
assert
type
(
ref_qconv
)
!=
snni
.
SpconvBnReLUNd
,
\
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return
super
().
from_reference
(
ref_qconv
[
0
],
output_scale
,
output_zero_point
)
spconv/pytorch/quantization/qmapping.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Optional
,
Union
,
Dict
,
Set
,
Callable
,
Any
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.conv
import
DEFAULT_SPARSE_CONV_TYPES
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
import
spconv.pytorch.quantization.quantized
as
snnq
STATIC_SPCONV_QUANT_MODULE_MAPPINGS
:
Dict
[
Callable
,
Any
]
=
{}
for
x
in
DEFAULT_SPARSE_CONV_TYPES
:
STATIC_SPCONV_QUANT_MODULE_MAPPINGS
[
x
]
=
nnq
.
SparseConv
STATIC_SPCONV_QUANT_MODULE_MAPPINGS
.
update
({
snni
.
SpconvReLUNd
:
snniq
.
SparseConvReLU
,
snniqat
.
SparseConvBn
:
snnq
.
SparseConv
,
snniqat
.
SparseConvBnReLU
:
snniq
.
SparseConvReLU
,
snniqat
.
SparseConvReLU
:
snniq
.
SparseConvReLU
,
})
SPCONV_QAT_MODULE_MAPPINGS
:
Dict
[
Callable
,
Any
]
=
{
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni
.
SpconvBnNd
:
snniqat
.
SparseConvBn
,
snni
.
SpconvBnReLUNd
:
snniqat
.
SparseConvBnReLU
,
snni
.
SpconvBnAddReLUNd
:
snniqat
.
SparseConvBnAddReLU
,
}
def
get_spconv_qat_to_static_mapping
():
return
STATIC_SPCONV_QUANT_MODULE_MAPPINGS
def
get_spconv_fmod_to_qat_mapping
():
return
SPCONV_QAT_MODULE_MAPPINGS
spconv/pytorch/quantization/quantized/__init__.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.conv
import
SparseConv
\ No newline at end of file
spconv/pytorch/quantization/quantized/conv.py
0 → 100644
View file @
b1c57a31
# coding=utf-8
r
"""Quantized convolution modules."""
from
typing
import
Optional
,
List
,
TypeVar
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
spconv.pytorch.conv
import
SparseConvolution
,
SparseConvolutionBase
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
spconv.core
import
ConvAlgo
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
from
torch._ops
import
ops
from
torch.nn.common_types
import
_size_1_t
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
,
_quantize_weight
import
spconv.pytorch.quantization.intrinsic.qat.modules
as
snniqat
import
spconv.pytorch.quantization.intrinsic.modules
as
snni
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_eval
,
fuse_spconv_bn_weights
class
_SparseConv
(
SparseConvolutionBase
,
WeightedQuantizedModule
):
_FLOAT_MODULE
=
SparseConvolution
_NNIQAT_CONV_BN_MODULE
=
snniqat
.
SparseConvBn
_NNI_CONV_RELU_MODULE
=
snni
.
SpconvReLUNd
def
__init__
(
self
,
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
device
=
None
,
dtype
=
None
):
SparseConvolutionBase
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
bias
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
)
WeightedQuantizedModule
.
__init__
(
self
)
factory_kwargs
=
{
'device'
:
device
,
'dtype'
:
dtype
}
qweight
=
torch
.
_empty_affine_quantized
(
self
.
weight_shape
,
scale
=
1
,
zero_point
=
0
,
dtype
=
torch
.
qint8
,
**
{
k
:
v
for
k
,
v
in
factory_kwargs
.
items
()
if
k
!=
'dtype'
})
bias_float
=
(
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
float
,
**
{
k
:
v
for
k
,
v
in
factory_kwargs
.
items
()
if
k
!=
'dtype'
})
if
bias
else
None
)
self
.
set_weight_bias
(
qweight
,
bias_float
)
self
.
scale
=
1.0
self
.
zero_point
=
0
def
set_weight_bias
(
self
,
qweight
,
bias_float
):
self
.
_weight
:
torch
.
Tensor
=
qweight
self
.
_bias
:
torch
.
Tensor
=
bias_float
def
bias
(
self
):
return
self
.
_bias
def
_weight_bias
(
self
):
return
(
self
.
_weight
,
self
.
_bias
)
def
extra_repr
(
self
):
s
=
(
'{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}, scale={scale}, zero_point={zero_point}'
)
if
self
.
padding
!=
(
0
,)
*
len
(
self
.
padding
):
s
+=
', padding={padding}'
if
self
.
dilation
!=
(
1
,)
*
len
(
self
.
dilation
):
s
+=
', dilation={dilation}'
if
self
.
output_padding
!=
(
0
,)
*
len
(
self
.
output_padding
):
s
+=
', output_padding={output_padding}'
if
self
.
groups
!=
1
:
s
+=
', groups={groups}'
if
self
.
bias
()
is
None
:
s
+=
', bias=False'
s
+=
f
', wqscheme=
{
self
.
_weight_bias
()[
0
].
qscheme
()
}
'
return
s
.
format
(
**
self
.
__dict__
)
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into
# their regular QTensor form for serialization. Packed weights should not
# live outside the process in which they were created, rather they should be
# derived from the QTensor weight.
# self
# |--- weight : Tensor
# |--- bias : Tensor
#
# TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
# self
# |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
super
(
_SparseConv
,
self
).
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
)
(
w
,
b
)
=
self
.
_weight_bias
()
destination
[
prefix
+
'weight'
]
=
w
destination
[
prefix
+
'bias'
]
=
b
destination
[
prefix
+
'scale'
]
=
torch
.
tensor
(
self
.
scale
)
destination
[
prefix
+
'zero_point'
]
=
torch
.
tensor
(
self
.
zero_point
)
# @torch.jit.export
# def __getstate__(self):
# (w, b) = self._weight_bias()
# return (
# self.in_channels,
# self.out_channels,
# self.kernel_size,
# self.stride,
# self.padding,
# self.dilation,
# self.transposed,
# self.output_padding,
# self.groups,
# self.padding_mode,
# w,
# b,
# self.scale,
# self.zero_point,
# self.training
# )
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized
# QTensor weight into its packed format for use by the FBGEMM ops.
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
self
.
set_weight_bias
(
state_dict
[
prefix
+
'weight'
],
state_dict
[
prefix
+
'bias'
])
state_dict
.
pop
(
prefix
+
'weight'
)
state_dict
.
pop
(
prefix
+
'bias'
)
self
.
scale
=
float
(
state_dict
[
prefix
+
'scale'
])
state_dict
.
pop
(
prefix
+
'scale'
)
self
.
zero_point
=
int
(
state_dict
[
prefix
+
'zero_point'
])
state_dict
.
pop
(
prefix
+
'zero_point'
)
super
(
_SparseConv
,
self
).
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
False
,
missing_keys
,
unexpected_keys
,
error_msgs
)
# @torch.jit.export
# def __setstate__(self, state):
# self.in_channels = state[0]
# self.out_channels = state[1]
# self.kernel_size = state[2]
# self.stride = state[3]
# self.padding = state[4]
# self.dilation = state[5]
# self.transposed = state[6]
# self.output_padding = state[7]
# self.groups = state[8]
# self.padding_mode = state[9]
# self.set_weight_bias(state[10], state[11])
# self.scale = state[12]
# self.zero_point = state[13]
# self.training = state[14]
# def __deepcopy__(self, memo):
# new_instance = type(self).__new__(type(self))
# torch.nn.Module.__init__(new_instance)
# state = self.__getstate__()
# new_instance.__setstate__(state)
# return new_instance
# def __copy__(self):
# return self.__deepcopy__({})
@
classmethod
def
get_qconv
(
cls
,
mod
,
activation_post_process
,
weight_post_process
=
None
):
r
"""Creates a qconv object and returns it.
"""
if
weight_post_process
is
None
:
weight_post_process
=
mod
.
qconfig
.
weight
()
weight_post_process
(
mod
.
weight
)
assert
weight_post_process
.
dtype
==
torch
.
qint8
,
\
'Weight observer must have a dtype of qint8'
qweight
=
_quantize_weight
(
mod
.
weight
.
float
(),
weight_post_process
)
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv
=
cls
(
mod
.
ndim
,
mod
.
in_channels
,
mod
.
out_channels
,
mod
.
kernel_size
,
mod
.
stride
,
mod
.
padding
,
mod
.
dilation
,
mod
.
groups
,
mod
.
bias
is
not
None
,
subm
=
mod
.
subm
,
output_padding
=
mod
.
output_padding
,
transposed
=
mod
.
transposed
,
inverse
=
mod
.
inverse
,
indice_key
=
mod
.
indice_key
,
algo
=
mod
.
algo
,
fp32_accum
=
mod
.
fp32_accum
,
record_voxel_count
=
mod
.
record_voxel_count
,
act_type
=
mod
.
act_type
,
act_alpha
=
mod
.
act_alpha
,
act_beta
=
mod
.
act_beta
)
qconv
.
set_weight_bias
(
qweight
,
mod
.
bias
)
if
activation_post_process
is
None
or
activation_post_process
.
dtype
==
torch
.
float
:
return
qconv
# dynamic quantization doesn't need scale/zero_point
else
:
act_scale
,
act_zp
=
activation_post_process
.
calculate_qparams
()
qconv
.
scale
=
float
(
act_scale
)
qconv
.
zero_point
=
int
(
act_zp
)
return
qconv
@
staticmethod
def
from_float
(
cls
,
mod
):
if
hasattr
(
mod
,
"weight_fake_quant"
):
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
# ".from_float only works for " + cls.__QAT_MODULE.__name__
if
type
(
mod
)
==
cls
.
_NNIQAT_CONV_BN_MODULE
:
mod
.
weight
,
mod
.
bias
=
fuse_spconv_bn_weights
(
mod
.
weight
,
mod
.
bias
,
mod
.
bn
.
running_mean
,
mod
.
bn
.
running_var
,
mod
.
bn
.
eps
,
mod
.
bn
.
weight
,
mod
.
bn
.
bias
)
assert
hasattr
(
mod
,
"activation_post_process"
),
\
"Input QAT module must have observer attached"
weight_post_process
=
mod
.
weight_fake_quant
activation_post_process
=
mod
.
activation_post_process
else
:
assert
type
(
mod
)
==
cls
.
_FLOAT_MODULE
,
\
" nnq."
+
cls
.
__name__
+
".from_float only works for "
+
\
cls
.
_FLOAT_MODULE
.
__name__
+
" but got:"
+
str
(
type
(
mod
))
assert
hasattr
(
mod
,
"qconfig"
),
\
"Input float module must have qconfig defined."
activation_post_process
=
None
if
not
hasattr
(
mod
,
"activation_post_process"
)
else
mod
.
activation_post_process
if
type
(
mod
)
==
cls
.
_NNI_CONV_RELU_MODULE
:
mod
=
mod
[
0
]
weight_post_process
=
mod
.
qconfig
.
weight
()
return
cls
.
get_qconv
(
mod
,
activation_post_process
,
weight_post_process
)
@
classmethod
def
from_reference
(
cls
,
ref_qconv
,
output_scale
,
output_zero_point
):
r
"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
output_zero_point (int): zero point for output Tensor
"""
qconv
=
cls
(
ref_qconv
.
ndim
,
ref_qconv
.
in_channels
,
ref_qconv
.
out_channels
,
ref_qconv
.
kernel_size
,
ref_qconv
.
stride
,
ref_qconv
.
padding
,
ref_qconv
.
dilation
,
ref_qconv
.
groups
,
ref_qconv
.
bias
is
not
None
,
subm
=
ref_qconv
.
subm
,
output_padding
=
ref_qconv
.
output_padding
,
transposed
=
ref_qconv
.
transposed
,
inverse
=
ref_qconv
.
inverse
,
indice_key
=
ref_qconv
.
indice_key
,
algo
=
ref_qconv
.
algo
,
fp32_accum
=
ref_qconv
.
fp32_accum
,
record_voxel_count
=
ref_qconv
.
record_voxel_count
,
act_type
=
ref_qconv
.
act_type
,
act_alpha
=
ref_qconv
.
act_alpha
,
act_beta
=
ref_qconv
.
act_beta
,
device
=
ref_qconv
.
weight
.
device
,
dtype
=
ref_qconv
.
weight
.
dtype
)
qweight
=
ref_qconv
.
get_quantized_weight
()
qconv
.
set_weight_bias
(
qweight
,
ref_qconv
.
bias
)
qconv
.
scale
=
float
(
output_scale
)
qconv
.
zero_point
=
int
(
output_zero_point
)
return
qconv
class
SparseConv
(
_SparseConv
):
r
"""Applies a 1D convolution over a quantized input signal composed of
several quantized input planes.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.Conv1d`.
.. note::
Only `zeros` is supported for the :attr:`padding_mode` argument.
.. note::
Only `torch.quint8` is supported for the input data type.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv1d` for other attributes.
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 100)
>>> # quantize input to quint8
>>> # xdoctest: +SKIP
>>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
... dtype=torch.quint8)
>>> output = m(q_input)
"""
_FLOAT_MODULE
=
SparseConvolution
_NNIQAT_CONV_BN_MODULE
=
snniqat
.
SparseConvBn
_NNI_CONV_RELU_MODULE
=
snni
.
SpconvReLUNd
def
_get_name
(
self
):
return
'QuantizedSparseConvolution'
def
set_weight_bias
(
self
,
w
:
torch
.
Tensor
,
b
:
Optional
[
torch
.
Tensor
])
->
None
:
assert
b
is
not
None
self
.
_weight
=
w
self
.
_bias
=
b
def
weight
(
self
):
return
self
.
_weight_bias
()[
0
]
def
bias
(
self
):
return
self
.
_weight_bias
()[
1
]
def
forward
(
self
,
input
:
SparseConvTensor
):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
print
(
"?"
)
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
()
out_scale
=
self
.
scale
channel_scale
=
out_scale
/
(
inp_scale
*
w_scales
)
bias
=
self
.
bias
()
*
out_scale
return
self
.
_conv_forward
(
False
,
input
,
self
.
weight
(),
bias
,
channel_scale
=
channel_scale
,
output_scale
=
out_scale
)
return
ops
.
quantized
.
conv1d
(
input
,
self
.
_packed_params
,
self
.
scale
,
self
.
zero_point
)
@
classmethod
def
from_float
(
cls
,
mod
):
r
"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
return
_SparseConv
.
from_float
(
cls
,
mod
)
spconv/pytorch/quantization/quantized/reference.py
0 → 100644
View file @
b1c57a31
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
typing
import
Optional
,
Dict
,
Any
,
List
from
torch.nn.common_types
import
_size_1_t
from
torch.ao.nn.quantized.reference.modules.utils
import
ReferenceQuantizedModule
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
spconv.core
import
ConvAlgo
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.core
import
SparseConvTensor
import
spconv.pytorch.conv
as
sconvmod
class
_SpConvNd
(
sconvmod
.
SparseConvolution
,
ReferenceQuantizedModule
):
""" A reference version of nn.quantized.Conv2d
we will not pack the parameters in this module, since weight packing is an
optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
this is useful when user want to use this module in other backends like Glow.
"""
__annotations__
=
{
"bias"
:
Optional
[
torch
.
Tensor
]}
_IS_REFERENCE
=
True
# @staticmethod
# def from_float(cls, float_conv, weight_qparams):
# qref_conv = cls(
# float_conv.in_channels,
# float_conv.out_channels,
# float_conv.kernel_size, # type: ignore[arg-type]
# float_conv.stride, # type: ignore[arg-type]
# float_conv.padding, # type: ignore[arg-type]
# float_conv.dilation, # type: ignore[arg-type]
# float_conv.groups,
# float_conv.bias is not None, # type: ignore[arg-type]
# float_conv.padding_mode,
# device=float_conv.weight.device,
# dtype=float_conv.weight.dtype,
# weight_qparams=weight_qparams)
# qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
# if float_conv.bias is not None:
# qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
# return qref_conv
@
staticmethod
def
from_float
(
cls
,
float_conv
,
weight_qparams
):
r
"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
conv
:
sconvmod
.
SparseConvolution
=
float_conv
qref_conv
=
cls
(
conv
.
ndim
,
conv
.
in_channels
,
conv
.
out_channels
,
conv
.
kernel_size
,
conv
.
stride
,
conv
.
padding
,
conv
.
dilation
,
conv
.
groups
,
conv
.
bias
is
not
None
,
subm
=
conv
.
subm
,
output_padding
=
conv
.
output_padding
,
transposed
=
conv
.
transposed
,
inverse
=
conv
.
inverse
,
indice_key
=
conv
.
indice_key
,
algo
=
conv
.
algo
,
fp32_accum
=
conv
.
fp32_accum
,
record_voxel_count
=
conv
.
record_voxel_count
,
act_type
=
conv
.
act_type
,
act_alpha
=
conv
.
act_alpha
,
act_beta
=
conv
.
act_beta
,
name
=
conv
.
name
,
device
=
float_conv
.
weight
.
device
,
dtype
=
float_conv
.
weight
.
dtype
,
weight_qparams
=
weight_qparams
)
qref_conv
.
weight
=
torch
.
nn
.
Parameter
(
float_conv
.
weight
.
detach
())
if
float_conv
.
bias
is
not
None
:
qref_conv
.
bias
=
torch
.
nn
.
Parameter
(
float_conv
.
bias
.
detach
())
return
qref_conv
class
SpConv
(
_SpConvNd
,
sconvmod
.
SparseConvolution
):
def
__init__
(
self
,
ndim
:
int
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
3
,
stride
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
dilation
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
subm
:
bool
=
False
,
output_padding
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...]]
=
0
,
transposed
:
bool
=
False
,
inverse
:
bool
=
False
,
indice_key
:
Optional
[
str
]
=
None
,
algo
:
Optional
[
ConvAlgo
]
=
None
,
fp32_accum
:
Optional
[
bool
]
=
None
,
record_voxel_count
:
bool
=
False
,
act_type
:
tv
.
gemm
.
Activation
=
tv
.
gemm
.
Activation
.
None_
,
act_alpha
:
float
=
0
,
act_beta
:
float
=
0
,
name
=
None
,
device
=
None
,
dtype
=
None
,
weight_qparams
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
sconvmod
.
SparseConvolution
.
__init__
(
self
,
ndim
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
bias
=
False
,
subm
=
subm
,
output_padding
=
output_padding
,
transposed
=
transposed
,
inverse
=
inverse
,
indice_key
=
indice_key
,
algo
=
algo
,
fp32_accum
=
fp32_accum
,
record_voxel_count
=
record_voxel_count
,
act_type
=
act_type
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
name
=
name
,
dtype
=
dtype
,
device
=
device
)
self
.
_init_weight_qparams
(
weight_qparams
,
device
)
def
forward
(
self
,
x
:
SparseConvTensor
)
->
SparseConvTensor
:
"""
we have:
w(float) -- quant - dequant
\
x(float) ------------- SparseConvolution ---
In the full model, we will see
w(float) -- quant - *dequant
\
x -- quant --- *dequant -- *SparseConvolution --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized SparseConvolution
"""
weight_quant_dequant
=
self
.
get_weight
()
result
=
self
.
_conv_forward
(
self
.
training
,
x
,
weight_quant_dequant
,
self
.
bias
)
return
result
def
_get_name
(
self
):
return
"QuantizedSparseConv(Reference)"
@
classmethod
def
from_float
(
cls
,
float_conv
,
weight_qparams
):
return
_SpConvNd
.
from_float
(
cls
,
float_conv
,
weight_qparams
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment