Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
5b3fe9e7
Commit
5b3fe9e7
authored
Jan 10, 2023
by
yan.yan
Browse files
sync quantization code
parent
e387ee74
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
758 additions
and
97 deletions
+758
-97
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+24
-23
spconv/algo.py
spconv/algo.py
+7
-6
spconv/core.py
spconv/core.py
+103
-13
spconv/pytorch/__init__.py
spconv/pytorch/__init__.py
+1
-0
spconv/pytorch/core.py
spconv/pytorch/core.py
+21
-9
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+4
-0
spconv/pytorch/modules.py
spconv/pytorch/modules.py
+22
-0
spconv/pytorch/quantization/__init__.py
spconv/pytorch/quantization/__init__.py
+5
-2
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+56
-13
spconv/pytorch/quantization/core.py
spconv/pytorch/quantization/core.py
+17
-1
spconv/pytorch/quantization/fake_q.py
spconv/pytorch/quantization/fake_q.py
+39
-11
spconv/pytorch/quantization/graph.py
spconv/pytorch/quantization/graph.py
+56
-0
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
+68
-18
spconv/pytorch/quantization/quantized/__init__.py
spconv/pytorch/quantization/quantized/__init__.py
+1
-1
spconv/pytorch/quantization/quantized/conv.py
spconv/pytorch/quantization/quantized/conv.py
+334
-0
No files found.
example/mnist/mnist_qat.py
View file @
5b3fe9e7
...
...
@@ -250,8 +250,8 @@ class NetV2(nn.Module):
)
self
.
fc1
=
nn
.
Linear
(
14
*
14
*
64
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
dropout1
=
nn
.
Dropout2d
(
0.25
)
self
.
dropout2
=
nn
.
Dropout2d
(
0.5
)
#
self.dropout1 = nn.Dropout2d(0.25)
#
self.dropout2 = nn.Dropout2d(0.5)
self
.
quant
=
QuantStub
()
self
.
dequant
=
DeQuantStub
()
...
...
@@ -263,10 +263,10 @@ class NetV2(nn.Module):
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x
=
self
.
net
(
x_sp
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
dropout1
(
x
)
#
x = self.dropout1(x)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
#
x = self.dropout2(x)
x
=
self
.
fc2
(
x
)
x
=
self
.
dequant
(
x
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
...
...
@@ -474,22 +474,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else
:
output
=
model
(
image
)
def
transform_qdq
(
m
:
torch
.
fx
.
GraphModule
)
->
torch
.
fx
.
GraphModule
:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for
node
in
m
.
graph
.
nodes
:
# Checks if we're calling a function (i.e:
# torch.add)
if
node
.
op
==
'call_function'
:
# The target attribute is the function
# that call_function calls.
if
node
.
target
==
torch
.
quantize_per_tensor
:
node
.
target
=
quantize_per_tensor
m
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
m
.
recompile
()
return
m
def
is_dequantize_node
(
node
):
...
...
@@ -522,6 +506,23 @@ def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
# Graph is well-formed.
return
model
def
transform_qdq
(
m
:
torch
.
fx
.
GraphModule
)
->
torch
.
fx
.
GraphModule
:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for
node
in
m
.
graph
.
nodes
:
# Checks if we're calling a function (i.e:
# torch.add)
if
node
.
op
==
'call_function'
:
# The target attribute is the function
# that call_function calls.
if
node
.
target
==
torch
.
quantize_per_tensor
:
node
.
target
=
quantize_per_tensor
m
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
m
.
recompile
()
return
m
def
main
():
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
...
@@ -561,7 +562,7 @@ def main():
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--sparse'
,
action
=
'store_true'
,
default
=
Tru
e
,
default
=
Fals
e
,
help
=
'use sparse conv network instead of dense'
)
parser
.
add_argument
(
'--log-interval'
,
...
...
@@ -588,7 +589,7 @@ def main():
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
model
=
ResidualNetPTQ
().
to
(
device
)
model
=
NetV2
().
to
(
device
)
else
:
model
=
NetDense
().
to
(
device
)
...
...
@@ -647,7 +648,7 @@ def main():
# print(prepared_model)
# calibrate: run model with some inputs
#
calibrate(args, prepared_model, test_loader, qdevice)
calibrate
(
args
,
prepared_model
,
test_loader
,
qdevice
)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model
=
qfx
.
convert_fx
(
prepared_model
,
qconfig_mapping
=
qconfig_mapping
,
backend_config
=
backend_cfg
)
converted_model
=
transform_qdq
(
converted_model
)
...
...
spconv/algo.py
View file @
5b3fe9e7
...
...
@@ -269,7 +269,8 @@ class SimpleGemm:
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
def
_compile_nvrtc_module
(
self
,
desp
:
GemmAlgoDesp
):
@
staticmethod
def
_compile_nvrtc_module
(
desp
:
GemmAlgoDesp
):
params
=
algocore
.
get_gemm_param_from_desp
(
desp
)
kernel
=
gen_gemm_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
...
...
@@ -808,7 +809,8 @@ class SimpleConv:
return
desp
.
query_conv_workspace_size
(
mnk
[
0
],
mnk
[
1
],
mnk
[
2
],
splitk
,
kv
)
def
_compile_nvrtc_module
(
self
,
desp
:
ConvAlgoDesp
):
@
staticmethod
def
_compile_nvrtc_module
(
desp
:
ConvAlgoDesp
):
params
=
algocore
.
get_conv_param_from_desp
(
desp
)
kernel
=
gen_conv_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
...
...
@@ -824,9 +826,8 @@ class SimpleConv:
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
True
,
custom_names
=
custom_names
,
verbose_path
=
"/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8"
)
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
...
...
@@ -870,7 +871,7 @@ class SimpleConv:
inp
=
inp
.
clone
()
weight
=
weight
.
clone
()
output
=
output
.
clone
()
print
(
len
(
avail
),
inp
.
dtype
,
weight
.
dtype
,
output
.
dtype
,
bias
.
dtype
,
scale
.
dtype
,
bias
.
empty
(),
scale
.
empty
())
#
print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty())
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
...
...
spconv/core.py
View file @
5b3fe9e7
...
...
@@ -410,6 +410,7 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_VOLTA_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
...
...
@@ -618,12 +619,26 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
]
if
not
SPCONV_INT8_DEBUG
:
IMPLGEMM_AMPERE_PARAMS
.
extend
([
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
),
(
64
,
32
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
...
...
@@ -632,14 +647,28 @@ if not SPCONV_INT8_DEBUG:
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
...
...
@@ -655,7 +684,7 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
)
,
(
32
,
64
,
64
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
64
)
,
(
64
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
...
...
@@ -671,7 +700,7 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
)
,
(
32
,
64
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
)
,
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
...
...
@@ -687,11 +716,27 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
,
"s8,s8,f32,s32,f32"
,
"s8,s8,f32,s32,f16"
,
"s8,s8,f16,s32,f32"
,
"s8,s8,f16,s32,f16"
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Ampere
,
TensorOp
((
16
,
8
,
32
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
[
2
,
3
,
4
],
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
...
...
@@ -1073,6 +1118,51 @@ IMPLGEMM_TURING_PARAMS = [
if
not
SPCONV_INT8_DEBUG
:
IMPLGEMM_TURING_PARAMS
.
extend
([
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
32
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
,
"s8,s8,s8,s32,f16"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
64
),
(
32
,
32
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
...
...
@@ -1136,7 +1226,7 @@ if not SPCONV_INT8_DEBUG:
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
),
# TODO 16,8,32 produce wrong result.
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
128
,
64
,
32
),
(
64
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
...
...
spconv/pytorch/__init__.py
View file @
5b3fe9e7
import
platform
from
pathlib
import
Path
from
typing
import
Union
import
numpy
as
np
import
torch
...
...
spconv/pytorch/core.py
View file @
5b3fe9e7
...
...
@@ -267,18 +267,30 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
# return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size
def
__add__
(
self
,
other
:
"SparseConvTensor"
):
assert
isinstance
(
other
,
SparseConvTensor
)
return
self
.
replace_feature
(
self
.
features
+
other
.
features
)
def
__add__
(
self
,
other
:
Union
[
"SparseConvTensor"
,
torch
.
Tensor
]):
assert
isinstance
(
other
,
(
SparseConvTensor
,
torch
.
Tensor
))
if
isinstance
(
other
,
torch
.
Tensor
):
other_features
=
other
else
:
other_features
=
other
.
features
return
self
.
replace_feature
(
self
.
features
+
other_features
)
def
__iadd__
(
self
,
other
:
"SparseConvTensor"
):
assert
isinstance
(
other
,
SparseConvTensor
)
self
.
features
+=
other
.
features
def
__iadd__
(
self
,
other
:
Union
[
"SparseConvTensor"
,
torch
.
Tensor
]):
assert
isinstance
(
other
,
(
SparseConvTensor
,
torch
.
Tensor
))
if
isinstance
(
other
,
torch
.
Tensor
):
other_features
=
other
else
:
other_features
=
other
.
features
self
.
features
+=
other_features
return
self
def
__radd__
(
self
,
other
:
"SparseConvTensor"
):
assert
isinstance
(
other
,
SparseConvTensor
)
return
other
.
replace_feature
(
self
.
features
+
other
.
features
)
def
__radd__
(
self
,
other
:
Union
[
"SparseConvTensor"
,
torch
.
Tensor
]):
assert
isinstance
(
other
,
(
SparseConvTensor
,
torch
.
Tensor
))
if
isinstance
(
other
,
torch
.
Tensor
):
other_features
=
other
else
:
other_features
=
other
.
features
return
self
.
replace_feature
(
self
.
features
+
other_features
)
def
shadow_copy
(
self
)
->
"SparseConvTensor"
:
"""create a new spconv tensor with all member unchanged"""
...
...
spconv/pytorch/cppcore.py
View file @
5b3fe9e7
...
...
@@ -137,6 +137,10 @@ class TorchAllocator(ExternalAllocator):
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
# if self.is_quantized:
# ctx = tv.Context()
# ctx.set_cuda_stream(stream)
# ten_tv.zero_(ctx)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
...
...
spconv/pytorch/modules.py
View file @
5b3fe9e7
...
...
@@ -15,6 +15,7 @@
import
sys
import
time
from
collections
import
OrderedDict
from
typing
import
Union
import
torch
from
torch
import
nn
...
...
@@ -182,3 +183,24 @@ class SparseIdentity(nn.Identity):
if
isinstance
(
input
,
spconv
.
SparseConvTensor
):
return
input
.
replace_feature
(
super
().
forward
(
input
.
features
))
return
super
().
forward
(
input
)
class
PrintTensorMeta
(
nn
.
Module
):
def
forward
(
self
,
x
:
Union
[
spconv
.
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
x
,
torch
.
Tensor
):
print
(
x
.
min
(),
x
.
max
(),
x
.
mean
())
elif
isinstance
(
x
,
spconv
.
SparseConvTensor
):
ft
=
x
.
features
print
(
ft
.
min
(),
ft
.
max
(),
ft
.
mean
())
return
x
class
PrintCurrentTime
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
first_time
=
time
.
time
()
def
forward
(
self
,
x
,
msg
=
""
,
reset
:
bool
=
False
):
if
reset
:
self
.
first_time
=
time
.
time
()
torch
.
cuda
.
synchronize
()
print
(
msg
,
time
.
time
()
-
self
.
first_time
)
return
x
spconv/pytorch/quantization/__init__.py
View file @
5b3fe9e7
...
...
@@ -16,7 +16,10 @@ from .backend_cfg import (get_spconv_backend_config,
get_spconv_prepare_custom_config
,
get_spconv_convert_custom_config
)
from
.fake_q
import
(
get_default_spconv_trt_ptq_qconfig
,
get_default_spconv_trt_qat_qconfig
)
get_default_spconv_trt_qat_qconfig
,
get_default_spconv_qconfig_mapping
)
from
.qmapping
import
(
get_spconv_fmod_to_qat_mapping
,
get_spconv_qat_to_static_mapping
)
from
.core
import
quantize_per_tensor
from
.graph
import
remove_conv_add_dq
,
transform_qdq
\ No newline at end of file
spconv/pytorch/quantization/backend_cfg.py
View file @
5b3fe9e7
from
collections
import
namedtuple
import
operator
from
typing
import
Dict
,
List
,
Tuple
,
Type
,
Union
from
collections
import
namedtuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.ao.quantization.fx.match_utils
import
(
MatchAllNode
,
)
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
from
torch.ao.quantization.backend_config
import
(
BackendConfig
,
BackendPatternConfig
,
...
...
@@ -15,9 +13,12 @@ from torch.ao.quantization.backend_config import (BackendConfig,
from
torch.ao.quantization.fx.custom_config
import
(
ConvertCustomConfig
,
FuseCustomConfig
,
PrepareCustomConfig
)
from
torch.ao.quantization.fx.match_utils
import
MatchAllNode
import
torch.nn.intrinsic
as
nni
import
torch.nn.intrinsic.qat
as
nniqat
import
torch.nn.quantized._reference
as
nnqr
import
spconv.pytorch.conv
as
sconvmod
from
spconv.pytorch.modules
import
SparseBatchNorm
,
SparseIdentity
,
SparseReLU
,
SparseSyncBatchNorm
import
spconv.pytorch.quantization.intrinsic
as
snni
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
...
...
@@ -25,10 +26,15 @@ import spconv.pytorch.quantization.quantized as snnq
import
spconv.pytorch.quantization.quantized.reference
as
snnqr
from
spconv.pytorch
import
ToDense
from
spconv.pytorch.constants
import
PYTORCH_VERSION
from
spconv.pytorch.modules
import
(
PrintTensorMeta
,
SparseBatchNorm
,
SparseIdentity
,
SparseReLU
,
SparseSyncBatchNorm
,
PrintCurrentTime
)
from
spconv.pytorch.pool
import
ALL_POOL_LAYERS
from
spconv.pytorch.quantization.fuse_mapping
import
(
fuse_conv_bn
,
fuse_conv_bn_relu
,
fuse_conv_bn_add_relu
)
fuse_conv_bn_add_relu
,
fuse_conv_bn_relu
)
_SpConvMetadataDef
=
namedtuple
(
"_ConvMetadata"
,
[
"root"
,
"bn"
,
"reference"
,
"fused_conv_relu"
,
"fused_conv_bn"
,
...
...
@@ -105,6 +111,31 @@ def _conv_res_relu_extra_inputs_getter(pattern):
return
[
extra_input
]
# def _get_custom_bn_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
# """
# Return all configs related to linear modules and ops.
# """
# observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
# linear_configs: List[BackendPatternConfig] = []
# # (3) Linear + batchnorm
# # ------------------------
# # 3.1 linear bn fusion
# if PYTORCH_VERSION[:2] <= [1, 13]:
# linear_configs.append(
# BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_linear_bn)
# .set_fused_module(nni.LinearBn1d))
# else:
# linear_configs.append(
# BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_linear_bn)
# .set_fused_module(nni.LinearBn1d))
# return linear_configs
def
_get_bn_spconv_configs
(
bn_cls
,
dtype_configs
):
"""
Return all configs related to conv modules and ops.
...
...
@@ -526,6 +557,9 @@ def _get_share_observer_ops(dtype_configs):
res
.
append
(
_to_dense_cfg
)
res
.
append
(
iden_cfg
)
res
.
append
(
BackendPatternConfig
(
PrintCurrentTime
).
set_observation_type
(
ObservationType
.
OUTPUT_SHARE_OBSERVER_WITH_INPUT
).
set_dtype_configs
(
dtype_configs
))
for
p
in
ALL_POOL_LAYERS
:
_pool_cfg
=
(
BackendPatternConfig
(
p
).
set_observation_type
(
...
...
@@ -551,31 +585,40 @@ conv_dtype_configs = [
weighted_op_qint8_dtype_config
,
]
backend_config
=
get_tensorrt_backend_config
()
\
.
set_backend_pattern_configs
(
_get_spconv_configs
(
conv_dtype_configs
)
+
_get_share_observer_ops
([
non_weighted_op_qint8_dtype_config
]))
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Tuple
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]]
=
{
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
snni
.
SpconvAddReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvAddReLU
),
# use simple cumm i8 conv to implement linear
nni
.
LinearReLU
:
(
nnqr
.
Linear
,
snniq
.
LinearPerChannelWeightReLU
),
}
SPCONV_STATIC_LOWER_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]
=
{
snnqr
.
SpConv
:
snnq
.
SparseConv
,
nnqr
.
Linear
:
snnq
.
LinearPerChannelWeight
,
}
def
get_spconv_backend_config
():
def
get_spconv_backend_config
(
additional_bns
:
Optional
[
List
[
Type
[
nn
.
Module
]]]
=
None
):
backend_config
=
get_tensorrt_backend_config
()
\
.
set_backend_pattern_configs
(
_get_spconv_configs
(
conv_dtype_configs
)
+
_get_share_observer_ops
([
non_weighted_op_qint8_dtype_config
]))
if
additional_bns
is
not
None
:
for
bn_type
in
additional_bns
:
backend_config
.
set_backend_pattern_configs
(
_get_bn_spconv_configs
(
bn_type
,
conv_dtype_configs
))
return
backend_config
def
get_spconv_prepare_custom_config
():
def
get_spconv_prepare_custom_config
(
additional_bns
:
Optional
[
List
[
Type
[
nn
.
Module
]]]
=
None
):
cfg
=
PrepareCustomConfig
()
cfg
.
non_traceable_module_classes
=
[
*
sconvmod
.
DEFAULT_SPARSE_CONV_TYPES
]
cfg
.
non_traceable_module_classes
.
extend
(
[
SparseReLU
,
SparseBatchNorm
,
SparseSyncBatchNorm
])
[
SparseReLU
,
SparseBatchNorm
,
SparseSyncBatchNorm
,
PrintTensorMeta
,
PrintCurrentTime
])
if
additional_bns
is
not
None
:
cfg
.
non_traceable_module_classes
.
extend
(
additional_bns
)
return
cfg
...
...
spconv/pytorch/quantization/core.py
View file @
5b3fe9e7
...
...
@@ -3,8 +3,12 @@ from typing import Union, List, Dict
import
torch
from
spconv.pytorch.core
import
SparseConvTensor
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.cppcore
import
get_current_stream
,
torch_tensor_to_tv
def
quantize_per_tensor
(
ten
:
Union
[
Union
[
SparseConvTensor
,
torch
.
Tensor
],
List
[
Union
[
SparseConvTensor
,
torch
.
Tensor
]]],
scale
,
zero_point
,
dtype
):
# with tv.measure_and_print("quantize_per_tensor", stream=get_current_stream()):
if
isinstance
(
ten
,
(
list
,
tuple
)):
res
=
[]
for
i
,
v
in
enumerate
(
ten
):
...
...
@@ -19,3 +23,15 @@ def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[U
else
:
return
torch
.
quantize_per_tensor
(
ten
,
scale
,
zero_point
,
dtype
)
def
quantized_add
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
scale
,
zero_point
):
x_detach
=
torch
.
zeros
(
size
=
x
.
shape
,
dtype
=
torch
.
int8
,
device
=
x
.
device
)
y_detach
=
torch
.
zeros
(
size
=
y
.
shape
,
dtype
=
torch
.
int8
,
device
=
y
.
device
)
torch_tensor_to_tv
(
x_detach
).
copy_
(
torch_tensor_to_tv
(
x
))
torch_tensor_to_tv
(
y_detach
).
copy_
(
torch_tensor_to_tv
(
y
))
res
=
(
x_detach
.
to
(
torch
.
float32
)
*
x
.
q_scale
()
+
y_detach
.
to
(
torch
.
float32
)
*
y
.
q_scale
())
/
scale
res
=
torch
.
clip
(
torch
.
round
(
res
),
-
128
,
127
).
to
(
torch
.
int8
)
res_q
=
torch
.
_empty_affine_quantized
(
size
=
res
.
shape
,
dtype
=
torch
.
qint8
,
scale
=
scale
,
zero_point
=
zero_point
,
device
=
x
.
device
)
torch_tensor_to_tv
(
res_q
,
tv
.
int8
).
copy_
(
torch_tensor_to_tv
(
res
))
return
res_q
spconv/pytorch/quantization/fake_q.py
View file @
5b3fe9e7
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Tuple
,
Union
import
torch
from
torch.ao.quantization
import
get_default_qat_qconfig
,
get_default_qconfig
from
torch.ao.quantization.fake_quantize
import
(
FixedQParamsFakeQuantize
,
FusedMovingAvgObsFakeQuantize
,
FakeQuantize
,
default_fused_per_channel_wt_fake_quant
,
default_weight_fake_quant
,
default_per_channel_weight_fake_quant
)
from
torch.ao.quantization.observer
import
(
HistogramObserver
,
MovingAverageMinMaxObserver
,
default_weight_observer
,
default_placeholder_observer
,
default_per_channel_weight_observer
)
from
torch.ao.quantization.qconfig
import
QConfig
,
QConfigAny
,
default_reuse_input_qconfig
from
torch.ao.quantization.qconfig_mapping
import
QConfigMapping
,
_FIXED_QPARAMS_OP_TO_OBSERVER
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
,
Union
,
List
from
torch.ao.quantization
import
get_default_qconfig
,
get_default_qat_qconfig
FakeQuantize
,
FixedQParamsFakeQuantize
,
FusedMovingAvgObsFakeQuantize
,
default_fused_per_channel_wt_fake_quant
,
default_per_channel_weight_fake_quant
,
default_weight_fake_quant
)
from
torch.ao.quantization.observer
import
(
MinMaxObserver
,
HistogramObserver
,
MovingAverageMinMaxObserver
,
default_per_channel_weight_observer
,
default_placeholder_observer
,
default_weight_observer
)
from
torch.ao.quantization.qconfig
import
(
QConfig
,
QConfigAny
,
default_reuse_input_qconfig
)
from
torch.ao.quantization.qconfig_mapping
import
(
_FIXED_QPARAMS_OP_TO_OBSERVER
,
QConfigMapping
)
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.modules
import
PrintTensorMeta
,
PrintCurrentTime
__all__
=
[
"get_default_spconv_trt_ptq_qconfig"
,
"get_default_spconv_trt_qat_qconfig"
]
...
...
@@ -26,6 +32,16 @@ class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
else
:
return
super
().
forward
(
input
)
class
SparseMovingAvgObsFakeQuantize
(
FakeQuantize
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
input
,
SparseConvTensor
):
# add lines to support spconv
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor):
...
...
@@ -46,6 +62,16 @@ class SparseHistogramObserver(HistogramObserver):
else
:
return
super
().
forward
(
input
)
class
SparseMinMaxObserver
(
MinMaxObserver
):
def
forward
(
self
,
input
:
Union
[
SparseConvTensor
,
torch
.
Tensor
]):
if
isinstance
(
input
,
SparseConvTensor
):
# add lines to support spconv
x
=
input
.
features
res_features
=
super
().
forward
(
x
)
return
input
.
replace_feature
(
res_features
)
else
:
return
super
().
forward
(
input
)
default_symmetric_spconv_ptq_qconfig
=
QConfig
(
activation
=
SparseHistogramObserver
.
with_args
(
quant_min
=-
128
,
quant_max
=
127
,
...
...
@@ -143,6 +169,8 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "fbgemm", ve
.
set_object_type
(
torch
.
nn
.
functional
.
leaky_relu
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
Tanh
,
qconfig
)
\
.
set_object_type
(
torch
.
nn
.
functional
.
tanh
,
qconfig
)
qconfig_mapping
.
set_object_type
(
PrintTensorMeta
,
None
)
qconfig_mapping
.
set_object_type
(
PrintCurrentTime
,
None
)
return
qconfig_mapping
spconv/pytorch/quantization/graph.py
0 → 100644
View file @
5b3fe9e7
import
torch.fx
import
torch
from
torch
import
nn
from
typing
import
Dict
,
Optional
from
spconv.pytorch.quantization.core
import
quantize_per_tensor
,
quantized_add
import
spconv.pytorch.quantization.intrinsic.quantized
as
snniq
def
is_dequantize_node
(
node
):
return
isinstance
(
node
,
torch
.
fx
.
Node
)
and
node
.
op
==
"call_method"
and
node
.
target
==
"dequantize"
def
_get_module
(
node
:
torch
.
fx
.
Node
,
modules
:
Dict
[
str
,
nn
.
Module
])
->
Optional
[
nn
.
Module
]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if
node
.
op
==
"call_module"
and
str
(
node
.
target
)
in
modules
:
return
modules
[
str
(
node
.
target
)]
else
:
return
None
def
remove_conv_add_dq
(
model
:
torch
.
fx
.
graph_module
.
GraphModule
):
modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
for
n
in
model
.
graph
.
nodes
:
if
(
n
.
op
==
"call_module"
and
type
(
_get_module
(
n
,
modules
))
==
snniq
.
SparseConvAddReLU
):
# check second input, if it's dequantized, remove that dequantize node
arg1
=
n
.
args
[
1
]
if
is_dequantize_node
(
arg1
):
dq_node
=
arg1
assert
(
isinstance
(
dq_node
,
torch
.
fx
.
Node
))
dn_input
=
dq_node
.
args
[
0
]
n
.
replace_input_with
(
dq_node
,
dn_input
)
model
.
graph
.
eliminate_dead_code
()
model
.
recompile
()
model
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
return
model
def
transform_qdq
(
m
:
torch
.
fx
.
GraphModule
)
->
torch
.
fx
.
GraphModule
:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for
node
in
m
.
graph
.
nodes
:
# Checks if we're calling a function (i.e:
# torch.add)
if
node
.
op
==
'call_function'
:
# The target attribute is the function
# that call_function calls.
if
node
.
target
==
torch
.
quantize_per_tensor
:
node
.
target
=
quantize_per_tensor
if
node
.
target
==
torch
.
ops
.
quantized
.
add
:
node
.
target
=
quantized_add
m
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
m
.
recompile
()
return
m
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
View file @
5b3fe9e7
...
...
@@ -14,16 +14,18 @@
from
typing
import
Optional
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.cppcore
import
get_current_stream
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
,
SpconvAddReLUNd
from
cumm
import
tensorview
as
tv
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_weights
import
torch.ao.nn.intrinsic
as
nni
import
spconv.pytorch.quantization.intrinsic.qat
as
snniqat
import
spconv.pytorch.quantization.intrinsic
as
snni
import
torch
__all__
=
[
"SparseConvReLU"
,
"SparseConvAddReLU"
]
__all__
=
[
"SparseConvReLU"
,
"SparseConvAddReLU"
,
"LinearPerChannelWeightReLU"
]
class
SparseConvReLU
(
nnq
.
SparseConv
):
r
"""
...
...
@@ -38,6 +40,10 @@ class SparseConvReLU(nnq.SparseConv):
_FLOAT_MODULE
=
SpconvReLUNd
# type: ignore[assignment]
def
forward
(
self
,
input
):
msg
=
f
"
{
input
.
features
.
shape
[
0
]
}
,
{
input
.
features
.
shape
[
1
]
}
,
{
self
.
weight
().
shape
[
0
]
}
"
with
tv
.
measure_and_print
(
f
"QuantizedSparseConvReLU|
{
msg
}
"
,
get_current_stream
(),
enable
=
False
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
self
.
scale
...
...
@@ -80,6 +86,9 @@ class SparseConvAddReLU(nnq.SparseConv):
_FLOAT_MODULE
=
SpconvAddReLUNd
# type: ignore[assignment]
def
forward
(
self
,
input
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
msg
=
f
"
{
input
.
features
.
shape
[
0
]
}
,
{
input
.
features
.
shape
[
1
]
}
,
{
self
.
weight
().
shape
[
0
]
}
"
with
tv
.
measure_and_print
(
f
"QuantizedSparseConvAddReLU|
{
msg
}
"
,
get_current_stream
(),
enable
=
False
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
self
.
scale
...
...
@@ -92,7 +101,7 @@ class SparseConvAddReLU(nnq.SparseConv):
return
res
def
_get_name
(
self
):
return
'QuantizedSparseConvReLU'
return
'QuantizedSparseConv
Add
ReLU'
@
classmethod
def
from_float
(
cls
,
mod
):
...
...
@@ -107,3 +116,44 @@ class SparseConvAddReLU(nnq.SparseConv):
assert
type
(
ref_qconv
)
!=
snni
.
SpconvBnReLUNd
,
\
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return
super
().
from_reference
(
ref_qconv
[
0
],
output_scale
,
output_zero_point
)
class
LinearPerChannelWeightReLU
(
nnq
.
LinearPerChannelWeight
):
r
"""
A LinearPerChannelWeight module fused from Linear and ReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE
=
nni
.
LinearReLU
def
__init__
(
self
,
in_features
,
out_features
,
bias
=
True
,
dtype
=
torch
.
qint8
):
super
().
__init__
(
in_features
,
out_features
,
bias
,
dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out
,
nvrtc_params
=
self
.
_linear_fwd
(
x
,
self
.
weight
(),
self
.
bias
(),
self
.
scale
,
tv
.
gemm
.
Activation
.
ReLU
,
self
.
_nvrtc_params
)
if
self
.
_nvrtc_params
is
None
:
self
.
_nvrtc_params
=
nvrtc_params
return
out
def
_get_name
(
self
):
return
'QuantizedLinearPerChannelWeightReLU'
@
classmethod
def
from_float
(
cls
,
mod
):
return
super
(
LinearPerChannelWeightReLU
,
cls
).
from_float
(
mod
)
@
classmethod
def
from_reference
(
cls
,
ref_linear_relu
,
output_scale
,
output_zero_point
):
return
super
().
from_reference
(
ref_linear_relu
[
0
],
output_scale
,
output_zero_point
)
spconv/pytorch/quantization/quantized/__init__.py
View file @
5b3fe9e7
...
...
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.conv
import
SparseConv
\ No newline at end of file
from
.conv
import
SparseConv
,
LinearPerChannelWeight
\ No newline at end of file
spconv/pytorch/quantization/quantized/conv.py
View file @
5b3fe9e7
...
...
@@ -16,11 +16,29 @@ from spconv.pytorch.core import SparseConvTensor
from
torch._ops
import
ops
from
torch.nn.common_types
import
_size_1_t
from
torch.nn.modules.utils
import
_single
,
_pair
,
_triple
from
collections.abc
import
Iterable
from
torch.ao.nn.quantized.modules.utils
import
WeightedQuantizedModule
,
_quantize_weight
import
spconv.pytorch.quantization.intrinsic.qat.modules
as
snniqat
import
spconv.pytorch.quantization.intrinsic.modules
as
snni
from
spconv.pytorch.quantization.utils
import
fuse_spconv_bn_eval
,
fuse_spconv_bn_weights
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm.tensorview.gemm
import
ConvAlgoDesp
from
cumm.tensorview.gemm
import
ConvOpType
as
ConvOpTypeCpp
from
spconv.constants
import
(
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
,
SPCONV_NVRTC_MODE
,
SPCONV_DEBUG_NVRTC_KERNELS
)
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
spconv
import
algocore
from
spconv.pytorch.cppcore
import
torch_tensor_to_tv
,
get_current_stream
import
torch.ao.nn.intrinsic
as
nni
import
torch.nn.intrinsic.qat
as
nniqat
from
torch.nn.utils.fusion
import
fuse_linear_bn_weights
from
torch.nn.utils.parametrize
import
type_before_parametrizations
from
spconv.algo
import
_get_nvrtc_params
,
SimpleConv
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
cumm.gemm.algospec.core
import
TensorOp
class
_SparseConv
(
SparseConvolutionBase
,
WeightedQuantizedModule
):
_FLOAT_MODULE
=
SparseConvolution
...
...
@@ -359,3 +377,319 @@ class SparseConv(_SparseConv):
return
_SparseConv
.
from_float
(
cls
,
mod
)
class
LinearPerChannelWeight
(
WeightedQuantizedModule
):
r
"""
A quantized linear module with quantized tensor as inputs and outputs.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
This module use conv int8 in cumm to provide qcuda int8 debug.
Similar to :class:`~torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized to zero.
scale: `scale` parameter of output Quantized Tensor, type: double
zero_point: `zero_point` parameter for output Quantized Tensor, type: long
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_version
=
3
_FLOAT_MODULE
=
(
nn
.
Linear
,
nn
.
modules
.
linear
.
NonDynamicallyQuantizableLinear
)
CUMM_CONV_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
2
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
access_per_vector
=
1
,
is_nvrtc
=
True
,
int8_inference
=
True
,
dynamic_mask
=
False
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
2
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
16
)),
access_per_vector
=
0
,
is_nvrtc
=
True
,
int8_inference
=
True
,
dynamic_mask
=
False
),
]
def
__init__
(
self
,
in_features
,
out_features
,
bias_
=
True
,
dtype
=
torch
.
qint8
):
super
().
__init__
()
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self
.
in_features
=
in_features
self
.
out_features
=
out_features
bias
=
None
if
bias_
:
bias
=
torch
.
zeros
(
out_features
,
dtype
=
torch
.
float
)
if
dtype
==
torch
.
qint8
:
qweight
=
torch
.
_empty_affine_quantized
(
[
out_features
,
in_features
],
scale
=
1
,
zero_point
=
0
,
dtype
=
torch
.
qint8
)
elif
dtype
==
torch
.
float16
:
qweight
=
torch
.
zeros
([
out_features
,
in_features
],
dtype
=
torch
.
float
)
else
:
raise
RuntimeError
(
'Unsupported dtype specified for quantized Linear!'
)
self
.
_weight
:
torch
.
Tensor
=
qweight
self
.
_bias
:
Optional
[
torch
.
Tensor
]
=
bias
self
.
scale
=
1.0
self
.
zero_point
=
0
self
.
_nvrtc_params
=
None
# this standard int8 conv operators is used for only quantization debug (to implement quantized Linear/Conv for qcuda backend)
def
_get_name
(
self
):
return
'QuantizedLinearPerChannelWeight'
def
extra_repr
(
self
):
return
'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'
.
format
(
self
.
in_features
,
self
.
out_features
,
self
.
scale
,
self
.
zero_point
,
self
.
weight
().
qscheme
()
)
@
staticmethod
def
_linear_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
scale
:
float
,
act
:
tv
.
gemm
.
Activation
,
nvrtc_params
):
is_ref
=
True
inp_scale
=
x
.
q_scale
()
w_scales
=
weight
.
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
scale
channel_scale
=
(
inp_scale
*
w_scales
)
/
out_scale
channel_k
=
weight
.
size
(
0
)
channel_c
=
weight
.
size
(
-
1
)
if
bias
is
not
None
:
bias
=
bias
/
out_scale
else
:
bias
=
torch
.
zeros
([
channel_k
],
dtype
=
torch
.
float32
,
device
=
x
.
device
)
ldi
=
x
.
size
(
-
1
)
ldw
=
weight
.
size
(
-
1
)
ldo
=
weight
.
size
(
0
)
params
=
ConvParams
(
2
,
ConvOpTypeCpp
(
ConvOpType
.
kForward
.
value
))
assert
len
(
LinearPerChannelWeight
.
CUMM_CONV_PARAMS
)
==
2
algo_desp_fast
=
algocore
.
get_conv_algo_desp_from_param
(
LinearPerChannelWeight
.
CUMM_CONV_PARAMS
[
0
])
algo_desp_generic
=
algocore
.
get_conv_algo_desp_from_param
(
LinearPerChannelWeight
.
CUMM_CONV_PARAMS
[
1
])
algo_desp
=
algo_desp_fast
if
not
algo_desp_fast
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
):
algo_desp
=
algo_desp_generic
# if not algo_desp.supported_ldx_conv(ldi, ldw, ldo):
# breakpoint()
if
is_ref
:
x_detach
=
torch
.
zeros
(
size
=
x
.
size
(),
dtype
=
torch
.
int8
,
device
=
x
.
device
)
weight_detach
=
torch
.
zeros
(
size
=
weight
.
size
(),
dtype
=
torch
.
int8
,
device
=
x
.
device
)
torch_tensor_to_tv
(
x_detach
).
copy_
(
torch_tensor_to_tv
(
x
))
torch_tensor_to_tv
(
weight_detach
).
copy_
(
torch_tensor_to_tv
(
weight
))
# o_tmp = torch.from_numpy(x_detach.to(torch.int32).cpu().numpy() @ weight_detach.to(torch.int32).cpu().numpy().T).to(x.device)
o_tmp
=
x_detach
.
to
(
torch
.
float32
)
@
weight_detach
.
to
(
torch
.
float32
).
T
o_tmp
=
o_tmp
.
to
(
torch
.
float32
)
*
channel_scale
+
bias
if
act
==
tv
.
gemm
.
Activation
.
ReLU
:
o_tmp
=
torch
.
maximum
(
o_tmp
,
torch
.
tensor
(
0
,
dtype
=
o_tmp
.
dtype
,
device
=
x
.
device
))
o_tmp
=
torch
.
clip
(
torch
.
round
(
o_tmp
),
-
128
,
127
).
to
(
torch
.
int8
)
output
=
torch
.
_empty_affine_quantized
(
o_tmp
.
shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
torch_tensor_to_tv
(
output
).
copy_
(
torch_tensor_to_tv
(
o_tmp
))
return
output
,
None
else
:
assert
algo_desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
out_shape
=
[
x
.
size
(
0
),
weight
.
size
(
0
)
]
output
=
torch
.
_empty_affine_quantized
(
out_shape
,
scale
=
scale
,
zero_point
=
0
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
params
.
conv_algo_desp
=
algo_desp
params
.
input
=
torch_tensor_to_tv
(
x
).
view
([
x
.
size
(
0
),
1
,
1
,
channel_c
])
params
.
verbose
=
False
params
.
weight
=
torch_tensor_to_tv
(
weight
).
view
([
channel_k
,
1
,
1
,
channel_c
])
params
.
output
=
torch_tensor_to_tv
(
output
).
view
([
x
.
size
(
0
),
1
,
1
,
channel_k
])
params
.
split_k_slices
=
1
params
.
alpha
=
1.0
params
.
beta
=
0.0
params
.
act_alpha
=
1.0
params
.
act_beta
=
0.0
params
.
act_type
=
act
params
.
padding
=
[
0
,
0
]
params
.
stride
=
[
1
,
1
]
params
.
dilation
=
[
1
,
1
]
params
.
stream
=
get_current_stream
()
if
nvrtc_params
is
None
:
mod
,
ker
=
SimpleConv
.
_compile_nvrtc_module
(
algo_desp
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"conv_kernel"
)
params
.
bias
=
torch_tensor_to_tv
(
bias
)
params
.
scale
=
torch_tensor_to_tv
(
channel_scale
)
params
.
nvrtc_params
=
nvrtc_params
tv
.
gemm
.
run_nvrtc_conv_kernel
(
params
)
return
output
,
nvrtc_params
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out
,
nvrtc_params
=
self
.
_linear_fwd
(
x
,
self
.
weight
(),
self
.
bias
(),
self
.
scale
,
tv
.
gemm
.
Activation
.
None_
,
self
.
_nvrtc_params
)
if
self
.
_nvrtc_params
is
None
:
self
.
_nvrtc_params
=
nvrtc_params
return
out
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
# regular QTensor form for serialization. Packed weights should not live
# outside the process in which they were created, rather they should be derived
# from the QTensor weight.
#
# Version 1
# self
# |--- scale : float
# |--- zero_point : int
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 3
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- _packed_params : (Tensor, Tensor) representing weight, bias
# of LinearPackedParams C++ struct
#
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
super
().
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
)
destination
[
prefix
+
'scale'
]
=
torch
.
tensor
(
self
.
scale
)
destination
[
prefix
+
'zero_point'
]
=
torch
.
tensor
(
self
.
zero_point
)
(
w
,
b
)
=
self
.
_weight_bias
()
destination
[
prefix
+
'weight'
]
=
w
destination
[
prefix
+
'bias'
]
=
b
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
self
.
scale
=
float
(
state_dict
[
prefix
+
'scale'
])
state_dict
.
pop
(
prefix
+
'scale'
)
self
.
zero_point
=
int
(
state_dict
[
prefix
+
'zero_point'
])
state_dict
.
pop
(
prefix
+
'zero_point'
)
version
=
local_metadata
.
get
(
'version'
,
None
)
# if version is None or version == 1:
# # We moved the parameters into a LinearPackedParameters submodule
# weight = state_dict.pop(prefix + 'weight')
# bias = state_dict.pop(prefix + 'bias')
# state_dict.update({prefix + '_packed_params.weight': weight,
# prefix + '_packed_params.bias': bias})
super
().
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
False
,
missing_keys
,
unexpected_keys
,
error_msgs
)
# Function rather than property to make sure that JIT serialization doesn't
# register this as an attribute
def
_weight_bias
(
self
):
return
(
self
.
_weight
,
self
.
_bias
)
def
weight
(
self
):
return
self
.
_weight_bias
()[
0
]
def
bias
(
self
):
return
self
.
_weight_bias
()[
1
]
def
set_weight_bias
(
self
,
w
:
torch
.
Tensor
,
b
:
Optional
[
torch
.
Tensor
])
->
None
:
self
.
_weight
=
w
self
.
_bias
=
b
# self._packed_params.set_weight_bias(w, b)
@
classmethod
def
from_float
(
cls
,
mod
):
r
"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
if
hasattr
(
mod
,
'weight_fake_quant'
):
if
type_before_parametrizations
(
mod
)
==
nniqat
.
LinearBn1d
:
mod
.
weight
,
mod
.
bias
=
fuse_linear_bn_weights
(
mod
.
weight
,
mod
.
bias
,
mod
.
bn
.
running_mean
,
mod
.
bn
.
running_var
,
mod
.
bn
.
eps
,
mod
.
bn
.
weight
,
mod
.
bn
.
bias
)
weight_post_process
=
mod
.
weight_fake_quant
activation_post_process
=
mod
.
activation_post_process
else
:
# This function does not participate in JIT, so it is OK to ignore
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if
not
isinstance
(
cls
.
_FLOAT_MODULE
,
Iterable
):
cls
.
_FLOAT_MODULE
=
[
cls
.
_FLOAT_MODULE
]
# type: ignore[assignment]
supported_modules
=
', '
.
join
([
float_mod
.
__name__
for
float_mod
in
cls
.
_FLOAT_MODULE
])
# type: ignore[attr-defined]
error_msg
=
'nnq.{}.from_float only works for {}, but got: {}'
.
format
(
cls
.
__name__
,
supported_modules
,
type
(
mod
))
assert
type_before_parametrizations
(
mod
)
in
cls
.
_FLOAT_MODULE
,
error_msg
.
format
()
# type: ignore[attr-defined]
assert
hasattr
(
mod
,
'qconfig'
),
'Input float module must have qconfig defined'
activation_post_process
=
mod
.
activation_post_process
if
type_before_parametrizations
(
mod
)
==
nni
.
LinearReLU
:
mod
=
mod
[
0
]
weight_post_process
=
mod
.
qconfig
.
weight
()
weight_post_process
(
mod
.
weight
)
dtype
=
weight_post_process
.
dtype
act_scale
,
act_zp
=
activation_post_process
.
calculate_qparams
()
assert
dtype
==
torch
.
qint8
,
'Weight observer must have dtype torch.qint8'
qweight
=
_quantize_weight
(
mod
.
weight
.
float
(),
weight_post_process
)
qlinear
=
cls
(
mod
.
in_features
,
mod
.
out_features
,
dtype
=
dtype
)
qlinear
.
set_weight_bias
(
qweight
,
mod
.
bias
)
qlinear
.
scale
=
float
(
act_scale
)
qlinear
.
zero_point
=
int
(
act_zp
)
return
qlinear
@
classmethod
def
from_reference
(
cls
,
ref_qlinear
,
output_scale
,
output_zero_point
):
r
"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
output_zero_point (int): zero point for output Tensor
"""
qlinear
=
cls
(
ref_qlinear
.
in_features
,
ref_qlinear
.
out_features
)
qweight
=
ref_qlinear
.
get_quantized_weight
()
qlinear
.
set_weight_bias
(
qweight
,
ref_qlinear
.
bias
)
qlinear
.
scale
=
float
(
output_scale
)
qlinear
.
zero_point
=
int
(
output_zero_point
)
return
qlinear
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment