Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
1f6deed6
Commit
1f6deed6
authored
Jan 17, 2023
by
yan.yan
Browse files
prepare int8 release
parent
5b3fe9e7
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
96 additions
and
86 deletions
+96
-86
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+3
-51
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+6
-3
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+11
-4
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+2
-0
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+1
-1
spconv/gencode/__main__.py
spconv/gencode/__main__.py
+5
-2
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+6
-2
spconv/pytorch/core.py
spconv/pytorch/core.py
+7
-1
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+5
-4
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+27
-10
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+2
-2
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
+1
-2
spconv/pytorch/quantization/quantized/conv.py
spconv/pytorch/quantization/quantized/conv.py
+12
-1
spconv/pytorch/quantization/quantized/reference.py
spconv/pytorch/quantization/quantized/reference.py
+2
-0
test/test_all_algo.py
test/test_all_algo.py
+6
-3
No files found.
example/mnist/mnist_qat.py
View file @
1f6deed6
...
...
@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module):
super
(
ResidualNetPTQ
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
# SubMConvBNReLU(32, 32, 3),
SparseBasicBlock2
(
32
,
32
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 14x14
...
...
@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else
:
output
=
model
(
image
)
def
is_dequantize_node
(
node
):
return
isinstance
(
node
,
torch
.
fx
.
Node
)
and
node
.
op
==
"call_method"
and
node
.
target
==
"dequantize"
def
_get_module
(
node
:
torch
.
fx
.
Node
,
modules
:
Dict
[
str
,
nn
.
Module
])
->
Optional
[
nn
.
Module
]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if
node
.
op
==
"call_module"
and
str
(
node
.
target
)
in
modules
:
return
modules
[
str
(
node
.
target
)]
else
:
return
None
def
remove_conv_add_dq
(
model
:
torch
.
fx
.
graph_module
.
GraphModule
):
modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
for
n
in
model
.
graph
.
nodes
:
if
(
n
.
op
==
"call_module"
and
type
(
_get_module
(
n
,
modules
))
==
snniq
.
SparseConvAddReLU
):
# check second input, if it's dequantized, remove that dequantize node
arg1
=
n
.
args
[
1
]
if
is_dequantize_node
(
arg1
):
dq_node
=
arg1
assert
(
isinstance
(
dq_node
,
torch
.
fx
.
Node
))
dn_input
=
dq_node
.
args
[
0
]
n
.
replace_input_with
(
dq_node
,
dn_input
)
model
.
graph
.
eliminate_dead_code
()
model
.
recompile
()
model
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
return
model
def
transform_qdq
(
m
:
torch
.
fx
.
GraphModule
)
->
torch
.
fx
.
GraphModule
:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for
node
in
m
.
graph
.
nodes
:
# Checks if we're calling a function (i.e:
# torch.add)
if
node
.
op
==
'call_function'
:
# The target attribute is the function
# that call_function calls.
if
node
.
target
==
torch
.
quantize_per_tensor
:
node
.
target
=
quantize_per_tensor
m
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
m
.
recompile
()
return
m
def
main
():
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
...
@@ -562,7 +514,7 @@ def main():
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--sparse'
,
action
=
'store_true'
,
default
=
Fals
e
,
default
=
Tru
e
,
help
=
'use sparse conv network instead of dense'
)
parser
.
add_argument
(
'--log-interval'
,
...
...
@@ -589,7 +541,7 @@ def main():
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
model
=
NetV2
().
to
(
device
)
model
=
ResidualNetPTQ
().
to
(
device
)
else
:
model
=
NetDense
().
to
(
device
)
...
...
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
1f6deed6
...
...
@@ -380,7 +380,7 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1
, do_sort: bool = True
) -> Tensor:
"""
Args:
data:
...
...
@@ -388,10 +388,11 @@ class SpconvOps:
indices:
stream:
mask_count:
do_sort:
"""
...
@staticmethod
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1
, do_sort: bool = True
) -> Tensor:
"""
Args:
data:
...
...
@@ -399,6 +400,7 @@ class SpconvOps:
indices:
stream:
mask_count:
do_sort:
"""
...
@staticmethod
...
...
@@ -555,7 +557,7 @@ class SpconvOps:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}
, do_sort: bool = True
) -> Tuple[Tensor, int]:
"""
Args:
allocator:
...
...
@@ -576,6 +578,7 @@ class SpconvOps:
timer:
direct_table:
preallocated:
do_sort:
"""
...
@staticmethod
...
...
spconv/csrc/sparse/all.py
View file @
1f6deed6
...
...
@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class):
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_count"
,
"int"
,
"1"
,
pyanno
=
"int"
)
code
.
arg
(
"do_sort"
,
"bool"
,
"true"
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
if
not
use_allocator
:
...
...
@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class):
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
if (!do_sort){{
return indices;
}}
// auto timer = tv::CUDATimer();
"""
)
# nested tv::dispatch may cause compiler bug in msvc.
...
...
@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"preallocated"
,
f
"std::unordered_map<std::string, tv::Tensor>"
,
"std::unordered_map<std::string, tv::Tensor>{}"
,
"Dict[str, cumm.tensorview.Tensor] = {}"
)
code
.
arg
(
"do_sort"
,
f
"bool"
,
"true"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
...
...
@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class):
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count
, do_sort
);
}}
"""
)
with
code
.
else_
():
...
...
@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n";
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{
TV_ASSERT_RT_ERR(do_sort, "not implemented for now");
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{
...
...
@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n";
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count
, do_sort
);
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count
, do_sort
);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
mask_argsort_bwd[0], stream_int, mask_int_count
, do_sort
);
}}
}}
}}
...
...
spconv/csrc/sparse/alloc.py
View file @
1f6deed6
...
...
@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
code
.
raw
(
f
"""
auto tvctx = tv::Context();
"""
)
...
...
@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
code
.
raw
(
f
"""
if (name ==
{
pccm
.
literal
(
AllocKeys
.
ThrustTemp
)
}
){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
...
...
spconv/csrc/sparse/convops.py
View file @
1f6deed6
...
...
@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
beta = output_add_scale
/ output_scale
;
}}
if (j > 0){{
...
...
spconv/gencode/__main__.py
View file @
1f6deed6
...
...
@@ -34,8 +34,11 @@ def main(include: str,
cu
.
namespace
=
"cumm.gemm.main"
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
# all_imp = IMPLGEMM_SIMT_PARAMS
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
# keep all int8 kernels in libspconv
for
x
in
all_imp
:
if
x
.
int8_inference
:
x
.
is_nvrtc
=
False
all_imp
=
list
(
filter
(
lambda
x
:
(
not
x
.
is_nvrtc
),
all_imp
))
if
inference_only
:
all_imp
=
list
(
filter
(
lambda
x
:
x
.
op_type
==
ConvOpType
.
kForward
,
all_imp
))
convcu
=
ConvMainUnitTest
(
all_imp
)
...
...
spconv/pytorch/conv.py
View file @
1f6deed6
...
...
@@ -137,6 +137,9 @@ class SparseConvolutionBase:
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
_conv_forward
(
self
,
training
:
bool
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
,
channel_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
float
]
=
None
,
name
:
Optional
[
str
]
=
None
,
sparse_unique_name
:
str
=
""
,
...
...
@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
s
+=
', bias=False'
if
self
.
algo
is
not
None
:
s
+=
f
', algo=
{
self
.
algo
}
'
if
self
.
act_type
!=
tv
.
gemm
.
Activation
.
None_
:
s
+=
f
', act=
{
self
.
act_type
}
'
return
s
.
format
(
**
self
.
__dict__
)
def
_calculate_fan_in_and_fan_out
(
self
):
...
...
@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound
=
1
/
math
.
sqrt
(
fan_in
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight
,
self
.
bias
,
add_input
,
...
...
spconv/pytorch/core.py
View file @
1f6deed6
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Union
,
Dict
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
,
Dict
import
numpy
as
np
import
torch
...
...
@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
self
.
int8_scale
:
Optional
[
np
.
ndarray
]
=
None
def
__repr__
(
self
):
return
f
"SparseConvTensor[shape=
{
self
.
_features
.
shape
}
]"
@
property
def
is_quantized
(
self
):
...
...
@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt
.
thrust_allocator
=
self
.
thrust_allocator
new_spt
.
_timer
=
self
.
_timer
new_spt
.
force_algo
=
self
.
force_algo
new_spt
.
int8_scale
=
self
.
int8_scale
return
new_spt
...
...
@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
tensor
.
thrust_allocator
=
self
.
thrust_allocator
tensor
.
_timer
=
self
.
_timer
tensor
.
force_algo
=
self
.
force_algo
tensor
.
int8_scale
=
self
.
int8_scale
return
tensor
def
expand_nd
(
ndim
:
int
,
val
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...],
np
.
ndarray
])
->
List
[
int
]:
...
...
spconv/pytorch/cppcore.py
View file @
1f6deed6
...
...
@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator):
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
# if self.is_quantized:
# ctx = tv.Context()
# ctx.set_cuda_stream(stream)
# ten_tv.zero_(ctx)
if
self
.
is_quantized
:
# no _zeros_affine_quantized available, so we need to zero_ here.
ctx
=
tv
.
Context
()
ctx
.
set_cuda_stream
(
stream
)
ten_tv
.
zero_
(
ctx
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
...
...
spconv/pytorch/ops.py
View file @
1f6deed6
...
...
@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor,
bias_tv
=
tv
.
Tensor
()
scale_tv
=
tv
.
Tensor
()
output_add_tv
=
tv
.
Tensor
()
is_int8
=
features
.
is_quantized
and
filters
.
is_quantized
if
output_add
is
not
None
:
assert
features
.
dtype
==
torch
.
qint8
,
"fused residual add only support int8"
if
bias
is
not
None
:
...
...
@@ -1535,6 +1536,23 @@ def implicit_gemm(features: torch.Tensor,
filters
=
filters
.
reshape
(
out_channel
,
-
1
,
filters
.
shape
[
-
1
])
kv
=
filters
.
shape
[
1
]
mask_int_count
=
div_up
(
kv
,
32
)
if
is_int8
:
if
is_subm
:
out_features
=
torch
.
_empty_affine_quantized
(
size
=
(
num_activate_out
,
out_channel
),
scale
=
output_scale
,
zero_point
=
0
,
dtype
=
features
.
dtype
,
device
=
features
.
device
)
# out_features = torch.empty((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else
:
out_features
=
torch
.
_empty_affine_quantized
(
size
=
(
num_activate_out
,
out_channel
),
scale
=
output_scale
,
zero_point
=
0
,
dtype
=
features
.
dtype
,
device
=
features
.
device
)
ctx
=
tv
.
Context
()
ctx
.
set_cuda_stream
(
stream
)
torch_tensor_to_tv
(
out_features
).
zero_
(
ctx
)
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else
:
if
is_subm
:
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
output_dtype
,
...
...
@@ -1543,7 +1561,6 @@ def implicit_gemm(features: torch.Tensor,
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
dtype
=
output_dtype
,
device
=
features
.
device
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
features_tv
=
torch_tensor_to_tv
(
features
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
...
...
@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor,
if
bias
is
not
None
and
not
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
1
if
output_add
is
not
None
and
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
output_add_scale
beta
=
output_add_scale
/
output_scale
CONV
.
run_with_tuned_result
(
tune_res
,
ConvOpType
.
kForward
,
...
...
@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
scale
=
scale_tv
,
output_add
=
output_add
)
output_add
=
output_add
_tv
)
return
out_features
,
mask_output_fwd
,
mask_width
...
...
spconv/pytorch/quantization/backend_cfg.py
View file @
1f6deed6
...
...
@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
snni
.
SpconvAddReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvAddReLU
),
# use simple cumm i8 conv to implement linear
nni
.
LinearReLU
:
(
nnqr
.
Linear
,
snniq
.
LinearPerChannelWeightReLU
),
#
nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
}
SPCONV_STATIC_LOWER_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]
=
{
snnqr
.
SpConv
:
snnq
.
SparseConv
,
nnqr
.
Linear
:
snnq
.
LinearPerChannelWeight
,
#
nnqr.Linear: snnq.LinearPerChannelWeight,
}
...
...
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
View file @
1f6deed6
...
...
@@ -14,7 +14,7 @@
from
typing
import
Optional
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.cppcore
import
get_current_stream
from
spconv.pytorch.cppcore
import
get_current_stream
,
torch_tensor_to_tv
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
,
SpconvAddReLUNd
from
cumm
import
tensorview
as
tv
...
...
@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv):
def
forward
(
self
,
input
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
msg
=
f
"
{
input
.
features
.
shape
[
0
]
}
,
{
input
.
features
.
shape
[
1
]
}
,
{
self
.
weight
().
shape
[
0
]
}
"
with
tv
.
measure_and_print
(
f
"QuantizedSparseConvAddReLU|
{
msg
}
"
,
get_current_stream
(),
enable
=
False
):
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
self
.
scale
...
...
spconv/pytorch/quantization/quantized/conv.py
View file @
1f6deed6
...
...
@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
bias_float
=
(
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
float
,
**
{
k
:
v
for
k
,
v
in
factory_kwargs
.
items
()
if
k
!=
'dtype'
})
if
bias
else
None
)
self
.
_max_voxels
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
set_weight_bias
(
qweight
,
bias_float
)
self
.
scale
=
1.0
self
.
zero_point
=
0
...
...
@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
self
.
_weight
:
torch
.
Tensor
=
qweight
self
.
_bias
:
torch
.
Tensor
=
bias_float
def
set_max_voxels
(
self
,
max_voxel
):
self
.
_max_voxels
=
max_voxel
def
bias
(
self
):
return
self
.
_bias
...
...
@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
destination
[
prefix
+
'bias'
]
=
b
destination
[
prefix
+
'scale'
]
=
torch
.
tensor
(
self
.
scale
)
destination
[
prefix
+
'zero_point'
]
=
torch
.
tensor
(
self
.
zero_point
)
destination
[
prefix
+
'max_voxels'
]
=
torch
.
tensor
(
self
.
_max_voxels
)
# @torch.jit.export
# def __getstate__(self):
...
...
@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
state_dict
.
pop
(
prefix
+
'weight'
)
state_dict
.
pop
(
prefix
+
'bias'
)
self
.
scale
=
float
(
state_dict
[
prefix
+
'scale'
])
state_dict
.
pop
(
prefix
+
'max_voxels'
)
self
.
_max_voxels
=
state_dict
[
prefix
+
'max_voxels'
]
state_dict
.
pop
(
prefix
+
'scale'
)
self
.
zero_point
=
int
(
state_dict
[
prefix
+
'zero_point'
])
state_dict
.
pop
(
prefix
+
'zero_point'
)
...
...
@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
assert
weight_post_process
.
dtype
==
torch
.
qint8
,
\
'Weight observer must have a dtype of qint8'
qweight
=
_quantize_weight
(
mod
.
weight
.
float
(),
weight_post_process
)
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv
=
cls
(
mod
.
ndim
,
mod
.
in_channels
,
mod
.
out_channels
,
mod
.
kernel_size
,
mod
.
stride
,
mod
.
padding
,
mod
.
dilation
,
...
...
@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
act_alpha
=
mod
.
act_alpha
,
act_beta
=
mod
.
act_beta
)
qconv
.
set_weight_bias
(
qweight
,
mod
.
bias
)
if
mod
.
get_max_num_voxels
()
is
not
None
:
qconv
.
set_max_voxels
(
mod
.
get_max_num_voxels
())
if
activation_post_process
is
None
or
activation_post_process
.
dtype
==
torch
.
float
:
return
qconv
# dynamic quantization doesn't need scale/zero_point
else
:
...
...
@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
qconv
.
set_weight_bias
(
qweight
,
ref_qconv
.
bias
)
qconv
.
scale
=
float
(
output_scale
)
qconv
.
zero_point
=
int
(
output_zero_point
)
if
ref_qconv
.
get_max_num_voxels
()
is
not
None
:
qconv
.
set_max_voxels
(
ref_qconv
.
get_max_num_voxels
())
return
qconv
...
...
spconv/pytorch/quantization/quantized/reference.py
View file @
1f6deed6
...
...
@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule):
qref_conv
.
weight
=
torch
.
nn
.
Parameter
(
float_conv
.
weight
.
detach
())
if
float_conv
.
bias
is
not
None
:
qref_conv
.
bias
=
torch
.
nn
.
Parameter
(
float_conv
.
bias
.
detach
())
if
conv
.
get_max_num_voxels
()
is
not
None
:
qref_conv
.
get_max_num_voxels
()[:]
=
conv
.
get_max_num_voxels
()
return
qref_conv
...
...
test/test_all_algo.py
View file @
1f6deed6
...
...
@@ -273,6 +273,9 @@ class SparseConvTester:
if
self
.
check_int8_infer
:
rescaled
=
output_ref
.
astype
(
self
.
dtype_comp
)
*
self
.
scales
.
astype
(
self
.
dtype_comp
)
rescaled
+=
self
.
bias
.
astype
(
self
.
dtype_comp
)
if
self
.
subm
:
rescaled
+=
self
.
output_add
.
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
else
:
rescaled
+=
self
.
output_add
[
self
.
out_order
].
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
if
self
.
check_act
:
rescaled
=
np
.
maximum
(
rescaled
,
0
)
...
...
@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool):
def
test_all_algo_unit
():
# for i in range(5):
#
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda
(
False
)
_test_impgemm_conv_cuda
(
True
)
#
_test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment