Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
1f6deed6
Commit
1f6deed6
authored
Jan 17, 2023
by
yan.yan
Browse files
prepare int8 release
parent
5b3fe9e7
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
96 additions
and
86 deletions
+96
-86
example/mnist/mnist_qat.py
example/mnist/mnist_qat.py
+3
-51
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+6
-3
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+11
-4
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+2
-0
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+1
-1
spconv/gencode/__main__.py
spconv/gencode/__main__.py
+5
-2
spconv/pytorch/conv.py
spconv/pytorch/conv.py
+6
-2
spconv/pytorch/core.py
spconv/pytorch/core.py
+7
-1
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+5
-4
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+27
-10
spconv/pytorch/quantization/backend_cfg.py
spconv/pytorch/quantization/backend_cfg.py
+2
-2
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
+1
-2
spconv/pytorch/quantization/quantized/conv.py
spconv/pytorch/quantization/quantized/conv.py
+12
-1
spconv/pytorch/quantization/quantized/reference.py
spconv/pytorch/quantization/quantized/reference.py
+2
-0
test/test_all_algo.py
test/test_all_algo.py
+6
-3
No files found.
example/mnist/mnist_qat.py
View file @
1f6deed6
...
@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module):
...
@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module):
super
(
ResidualNetPTQ
,
self
).
__init__
()
super
(
ResidualNetPTQ
,
self
).
__init__
()
self
.
net
=
spconv
.
SparseSequential
(
self
.
net
=
spconv
.
SparseSequential
(
SubMConvBNReLU
(
1
,
32
,
3
),
SubMConvBNReLU
(
1
,
32
,
3
),
# SubMConvBNReLU(32, 32, 3),
SparseBasicBlock2
(
32
,
32
),
SparseBasicBlock2
(
32
,
32
),
SubMConvBNReLU
(
32
,
64
,
3
),
SubMConvBNReLU
(
32
,
64
,
3
),
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 14x14
SparseConvBNReLU
(
64
,
64
,
2
,
2
),
# 14x14
...
@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
...
@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else
:
else
:
output
=
model
(
image
)
output
=
model
(
image
)
def
is_dequantize_node
(
node
):
return
isinstance
(
node
,
torch
.
fx
.
Node
)
and
node
.
op
==
"call_method"
and
node
.
target
==
"dequantize"
def
_get_module
(
node
:
torch
.
fx
.
Node
,
modules
:
Dict
[
str
,
nn
.
Module
])
->
Optional
[
nn
.
Module
]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if
node
.
op
==
"call_module"
and
str
(
node
.
target
)
in
modules
:
return
modules
[
str
(
node
.
target
)]
else
:
return
None
def
remove_conv_add_dq
(
model
:
torch
.
fx
.
graph_module
.
GraphModule
):
modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
for
n
in
model
.
graph
.
nodes
:
if
(
n
.
op
==
"call_module"
and
type
(
_get_module
(
n
,
modules
))
==
snniq
.
SparseConvAddReLU
):
# check second input, if it's dequantized, remove that dequantize node
arg1
=
n
.
args
[
1
]
if
is_dequantize_node
(
arg1
):
dq_node
=
arg1
assert
(
isinstance
(
dq_node
,
torch
.
fx
.
Node
))
dn_input
=
dq_node
.
args
[
0
]
n
.
replace_input_with
(
dq_node
,
dn_input
)
model
.
graph
.
eliminate_dead_code
()
model
.
recompile
()
model
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
return
model
def
transform_qdq
(
m
:
torch
.
fx
.
GraphModule
)
->
torch
.
fx
.
GraphModule
:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for
node
in
m
.
graph
.
nodes
:
# Checks if we're calling a function (i.e:
# torch.add)
if
node
.
op
==
'call_function'
:
# The target attribute is the function
# that call_function calls.
if
node
.
target
==
torch
.
quantize_per_tensor
:
node
.
target
=
quantize_per_tensor
m
.
graph
.
lint
()
# Does some checks to make sure the
# Graph is well-formed.
m
.
recompile
()
return
m
def
main
():
def
main
():
# Training settings
# Training settings
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch MNIST Example'
)
...
@@ -562,7 +514,7 @@ def main():
...
@@ -562,7 +514,7 @@ def main():
help
=
'random seed (default: 1)'
)
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--sparse'
,
parser
.
add_argument
(
'--sparse'
,
action
=
'store_true'
,
action
=
'store_true'
,
default
=
Fals
e
,
default
=
Tru
e
,
help
=
'use sparse conv network instead of dense'
)
help
=
'use sparse conv network instead of dense'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--log-interval'
,
'--log-interval'
,
...
@@ -589,7 +541,7 @@ def main():
...
@@ -589,7 +541,7 @@ def main():
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
qdevice
=
torch
.
device
(
"cuda"
if
use_cuda
and
args
.
sparse
else
"cpu"
)
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
kwargs
=
{
'num_workers'
:
1
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
if
args
.
sparse
:
if
args
.
sparse
:
model
=
NetV2
().
to
(
device
)
model
=
ResidualNetPTQ
().
to
(
device
)
else
:
else
:
model
=
NetDense
().
to
(
device
)
model
=
NetDense
().
to
(
device
)
...
...
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
1f6deed6
...
@@ -380,7 +380,7 @@ class SpconvOps:
...
@@ -380,7 +380,7 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1
, do_sort: bool = True
) -> Tensor:
"""
"""
Args:
Args:
data:
data:
...
@@ -388,10 +388,11 @@ class SpconvOps:
...
@@ -388,10 +388,11 @@ class SpconvOps:
indices:
indices:
stream:
stream:
mask_count:
mask_count:
do_sort:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1
, do_sort: bool = True
) -> Tensor:
"""
"""
Args:
Args:
data:
data:
...
@@ -399,6 +400,7 @@ class SpconvOps:
...
@@ -399,6 +400,7 @@ class SpconvOps:
indices:
indices:
stream:
stream:
mask_count:
mask_count:
do_sort:
"""
"""
...
...
@staticmethod
@staticmethod
...
@@ -555,7 +557,7 @@ class SpconvOps:
...
@@ -555,7 +557,7 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}
, do_sort: bool = True
) -> Tuple[Tensor, int]:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -576,6 +578,7 @@ class SpconvOps:
...
@@ -576,6 +578,7 @@ class SpconvOps:
timer:
timer:
direct_table:
direct_table:
preallocated:
preallocated:
do_sort:
"""
"""
...
...
@staticmethod
@staticmethod
...
...
spconv/csrc/sparse/all.py
View file @
1f6deed6
...
@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class):
...
@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class):
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
pyanno
=
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_count"
,
"int"
,
"1"
,
pyanno
=
"int"
)
code
.
arg
(
"mask_count"
,
"int"
,
"1"
,
pyanno
=
"int"
)
code
.
arg
(
"do_sort"
,
"bool"
,
"true"
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
code
.
add_param_class
(
"cudakers"
,
self
.
cuda_common_kernel
)
if
not
use_allocator
:
if
not
use_allocator
:
...
@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class):
...
@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class):
}}
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
if (!do_sort){{
return indices;
}}
// auto timer = tv::CUDATimer();
// auto timer = tv::CUDATimer();
"""
)
"""
)
# nested tv::dispatch may cause compiler bug in msvc.
# nested tv::dispatch may cause compiler bug in msvc.
...
@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class):
...
@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"preallocated"
,
f
"std::unordered_map<std::string, tv::Tensor>"
,
code
.
arg
(
"preallocated"
,
f
"std::unordered_map<std::string, tv::Tensor>"
,
"std::unordered_map<std::string, tv::Tensor>{}"
,
"std::unordered_map<std::string, tv::Tensor>{}"
,
"Dict[str, cumm.tensorview.Tensor] = {}"
)
"Dict[str, cumm.tensorview.Tensor] = {}"
)
code
.
arg
(
"do_sort"
,
f
"bool"
,
"true"
)
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
...
@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class):
...
@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class):
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count
, do_sort
);
}}
}}
"""
)
"""
)
with
code
.
else_
():
with
code
.
else_
():
...
@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n";
...
@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n";
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int));
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{
if (is_mask_split){{
TV_ASSERT_RT_ERR(do_sort, "not implemented for now");
for (int j = 0; j < mask_split_count; ++j){{
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{
if (!is_train){{
...
@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n";
...
@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n";
}}else{{
}}else{{
if (!is_train){{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count
, do_sort
);
}}else{{
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count
, do_sort
);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
mask_argsort_bwd[0], stream_int, mask_int_count
, do_sort
);
}}
}}
}}
}}
}}
}}
...
...
spconv/csrc/sparse/alloc.py
View file @
1f6deed6
...
@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator):
...
@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto tvctx = tv::Context();
auto tvctx = tv::Context();
"""
)
"""
)
...
@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator):
...
@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator):
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"scale"
,
"float"
,
"1.0"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
if (name ==
{
pccm
.
literal
(
AllocKeys
.
ThrustTemp
)
}
){{
if (name ==
{
pccm
.
literal
(
AllocKeys
.
ThrustTemp
)
}
){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
...
...
spconv/csrc/sparse/convops.py
View file @
1f6deed6
...
@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
// use source as bias
beta = output_add_scale;
beta = output_add_scale
/ output_scale
;
}}
}}
if (j > 0){{
if (j > 0){{
...
...
spconv/gencode/__main__.py
View file @
1f6deed6
...
@@ -34,8 +34,11 @@ def main(include: str,
...
@@ -34,8 +34,11 @@ def main(include: str,
cu
.
namespace
=
"cumm.gemm.main"
cu
.
namespace
=
"cumm.gemm.main"
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
)
# all_imp = IMPLGEMM_SIMT_PARAMS
# keep all int8 kernels in libspconv
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
for
x
in
all_imp
:
if
x
.
int8_inference
:
x
.
is_nvrtc
=
False
all_imp
=
list
(
filter
(
lambda
x
:
(
not
x
.
is_nvrtc
),
all_imp
))
if
inference_only
:
if
inference_only
:
all_imp
=
list
(
filter
(
lambda
x
:
x
.
op_type
==
ConvOpType
.
kForward
,
all_imp
))
all_imp
=
list
(
filter
(
lambda
x
:
x
.
op_type
==
ConvOpType
.
kForward
,
all_imp
))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
=
ConvMainUnitTest
(
all_imp
)
...
...
spconv/pytorch/conv.py
View file @
1f6deed6
...
@@ -137,6 +137,9 @@ class SparseConvolutionBase:
...
@@ -137,6 +137,9 @@ class SparseConvolutionBase:
if
self
.
conv1x1
:
if
self
.
conv1x1
:
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
assert
act_type
==
tv
.
gemm
.
Activation
.
None_
,
"conv1x1 don't support fused act"
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
_conv_forward
(
self
,
training
:
bool
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
,
def
_conv_forward
(
self
,
training
:
bool
,
input
:
SparseConvTensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
add_input
:
Optional
[
SparseConvTensor
]
=
None
,
channel_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
float
]
=
None
,
name
:
Optional
[
str
]
=
None
,
channel_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
float
]
=
None
,
name
:
Optional
[
str
]
=
None
,
sparse_unique_name
:
str
=
""
,
sparse_unique_name
:
str
=
""
,
...
@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
...
@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
s
+=
', bias=False'
s
+=
', bias=False'
if
self
.
algo
is
not
None
:
if
self
.
algo
is
not
None
:
s
+=
f
', algo=
{
self
.
algo
}
'
s
+=
f
', algo=
{
self
.
algo
}
'
if
self
.
act_type
!=
tv
.
gemm
.
Activation
.
None_
:
s
+=
f
', act=
{
self
.
act_type
}
'
return
s
.
format
(
**
self
.
__dict__
)
return
s
.
format
(
**
self
.
__dict__
)
def
_calculate_fan_in_and_fan_out
(
self
):
def
_calculate_fan_in_and_fan_out
(
self
):
...
@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
...
@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound
=
1
/
math
.
sqrt
(
fan_in
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
init
.
uniform_
(
self
.
bias
,
-
bound
,
bound
)
def
is_inverseable
(
self
):
return
self
.
indice_key
is
not
None
and
not
self
.
subm
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
def
forward
(
self
,
input
:
SparseConvTensor
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
return
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight
,
self
.
bias
,
add_input
,
return
self
.
_conv_forward
(
self
.
training
,
input
,
self
.
weight
,
self
.
bias
,
add_input
,
...
...
spconv/pytorch/core.py
View file @
1f6deed6
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Union
,
Dict
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
,
Dict
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
thrust_allocator
=
ThrustSortAllocator
(
features
.
device
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
_timer
=
CUDAKernelTimer
(
enable_timer
)
self
.
force_algo
=
force_algo
self
.
force_algo
=
force_algo
self
.
int8_scale
:
Optional
[
np
.
ndarray
]
=
None
def
__repr__
(
self
):
return
f
"SparseConvTensor[shape=
{
self
.
_features
.
shape
}
]"
@
property
@
property
def
is_quantized
(
self
):
def
is_quantized
(
self
):
...
@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt
.
thrust_allocator
=
self
.
thrust_allocator
new_spt
.
thrust_allocator
=
self
.
thrust_allocator
new_spt
.
_timer
=
self
.
_timer
new_spt
.
_timer
=
self
.
_timer
new_spt
.
force_algo
=
self
.
force_algo
new_spt
.
force_algo
=
self
.
force_algo
new_spt
.
int8_scale
=
self
.
int8_scale
return
new_spt
return
new_spt
...
@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
...
@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
tensor
.
thrust_allocator
=
self
.
thrust_allocator
tensor
.
thrust_allocator
=
self
.
thrust_allocator
tensor
.
_timer
=
self
.
_timer
tensor
.
_timer
=
self
.
_timer
tensor
.
force_algo
=
self
.
force_algo
tensor
.
force_algo
=
self
.
force_algo
tensor
.
int8_scale
=
self
.
int8_scale
return
tensor
return
tensor
def
expand_nd
(
ndim
:
int
,
val
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...],
np
.
ndarray
])
->
List
[
int
]:
def
expand_nd
(
ndim
:
int
,
val
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
...],
np
.
ndarray
])
->
List
[
int
]:
...
...
spconv/pytorch/cppcore.py
View file @
1f6deed6
...
@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator):
...
@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator):
else
:
else
:
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten
=
torch
.
empty
(
shape
,
dtype
=
th_dtype
,
device
=
dev
).
zero_
()
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
ten_tv
=
torch_tensor_to_tv
(
ten
,
dtype_bkp
)
# if self.is_quantized:
if
self
.
is_quantized
:
# ctx = tv.Context()
# no _zeros_affine_quantized available, so we need to zero_ here.
# ctx.set_cuda_stream(stream)
ctx
=
tv
.
Context
()
# ten_tv.zero_(ctx)
ctx
.
set_cuda_stream
(
stream
)
ten_tv
.
zero_
(
ctx
)
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
self
.
allocated
[
ten_tv
.
byte_pointer
()]
=
ten
if
name
and
not
is_temp_memory
:
if
name
and
not
is_temp_memory
:
self
.
allocated
[
name
]
=
ten
self
.
allocated
[
name
]
=
ten
...
...
spconv/pytorch/ops.py
View file @
1f6deed6
...
@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor,
bias_tv
=
tv
.
Tensor
()
bias_tv
=
tv
.
Tensor
()
scale_tv
=
tv
.
Tensor
()
scale_tv
=
tv
.
Tensor
()
output_add_tv
=
tv
.
Tensor
()
output_add_tv
=
tv
.
Tensor
()
is_int8
=
features
.
is_quantized
and
filters
.
is_quantized
if
output_add
is
not
None
:
if
output_add
is
not
None
:
assert
features
.
dtype
==
torch
.
qint8
,
"fused residual add only support int8"
assert
features
.
dtype
==
torch
.
qint8
,
"fused residual add only support int8"
if
bias
is
not
None
:
if
bias
is
not
None
:
...
@@ -1535,6 +1536,23 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1535,6 +1536,23 @@ def implicit_gemm(features: torch.Tensor,
filters
=
filters
.
reshape
(
out_channel
,
-
1
,
filters
.
shape
[
-
1
])
filters
=
filters
.
reshape
(
out_channel
,
-
1
,
filters
.
shape
[
-
1
])
kv
=
filters
.
shape
[
1
]
kv
=
filters
.
shape
[
1
]
mask_int_count
=
div_up
(
kv
,
32
)
mask_int_count
=
div_up
(
kv
,
32
)
if
is_int8
:
if
is_subm
:
out_features
=
torch
.
_empty_affine_quantized
(
size
=
(
num_activate_out
,
out_channel
),
scale
=
output_scale
,
zero_point
=
0
,
dtype
=
features
.
dtype
,
device
=
features
.
device
)
# out_features = torch.empty((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else
:
out_features
=
torch
.
_empty_affine_quantized
(
size
=
(
num_activate_out
,
out_channel
),
scale
=
output_scale
,
zero_point
=
0
,
dtype
=
features
.
dtype
,
device
=
features
.
device
)
ctx
=
tv
.
Context
()
ctx
.
set_cuda_stream
(
stream
)
torch_tensor_to_tv
(
out_features
).
zero_
(
ctx
)
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else
:
if
is_subm
:
if
is_subm
:
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
out_features
=
torch
.
empty
((
num_activate_out
,
out_channel
),
dtype
=
output_dtype
,
dtype
=
output_dtype
,
...
@@ -1543,7 +1561,6 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1543,7 +1561,6 @@ def implicit_gemm(features: torch.Tensor,
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
out_features
=
torch
.
zeros
((
num_activate_out
,
out_channel
),
dtype
=
output_dtype
,
dtype
=
output_dtype
,
device
=
features
.
device
)
device
=
features
.
device
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
pair_fwd_tv
=
torch_tensor_to_tv
(
pair_fwd
)
features_tv
=
torch_tensor_to_tv
(
features
)
features_tv
=
torch_tensor_to_tv
(
features
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
...
@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor,
if
bias
is
not
None
and
not
tune_res
.
algo_desp
.
is_int8_inference
:
if
bias
is
not
None
and
not
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
1
beta
=
1
if
output_add
is
not
None
and
tune_res
.
algo_desp
.
is_int8_inference
:
if
output_add
is
not
None
and
tune_res
.
algo_desp
.
is_int8_inference
:
beta
=
output_add_scale
beta
=
output_add_scale
/
output_scale
CONV
.
run_with_tuned_result
(
CONV
.
run_with_tuned_result
(
tune_res
,
tune_res
,
ConvOpType
.
kForward
,
ConvOpType
.
kForward
,
...
@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor,
act_alpha
=
act_alpha
,
act_alpha
=
act_alpha
,
act_beta
=
act_beta
,
act_beta
=
act_beta
,
scale
=
scale_tv
,
scale
=
scale_tv
,
output_add
=
output_add
)
output_add
=
output_add
_tv
)
return
out_features
,
mask_output_fwd
,
mask_width
return
out_features
,
mask_output_fwd
,
mask_width
...
...
spconv/pytorch/quantization/backend_cfg.py
View file @
1f6deed6
...
@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
...
@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
snni
.
SpconvReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvReLU
),
snni
.
SpconvAddReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvAddReLU
),
snni
.
SpconvAddReLUNd
:
(
snnqr
.
SpConv
,
snniq
.
SparseConvAddReLU
),
# use simple cumm i8 conv to implement linear
# use simple cumm i8 conv to implement linear
nni
.
LinearReLU
:
(
nnqr
.
Linear
,
snniq
.
LinearPerChannelWeightReLU
),
#
nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
}
}
SPCONV_STATIC_LOWER_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
SPCONV_STATIC_LOWER_MODULE_MAP
:
Dict
[
Type
[
nn
.
Module
],
Type
[
WeightedQuantizedModule
]]
=
{
Type
[
WeightedQuantizedModule
]]
=
{
snnqr
.
SpConv
:
snnq
.
SparseConv
,
snnqr
.
SpConv
:
snnq
.
SparseConv
,
nnqr
.
Linear
:
snnq
.
LinearPerChannelWeight
,
#
nnqr.Linear: snnq.LinearPerChannelWeight,
}
}
...
...
spconv/pytorch/quantization/intrinsic/quantized/conv_relu.py
View file @
1f6deed6
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
from
typing
import
Optional
from
typing
import
Optional
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.core
import
SparseConvTensor
from
spconv.pytorch.cppcore
import
get_current_stream
from
spconv.pytorch.cppcore
import
get_current_stream
,
torch_tensor_to_tv
import
spconv.pytorch.quantization.quantized
as
nnq
import
spconv.pytorch.quantization.quantized
as
nnq
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
,
SpconvAddReLUNd
from
spconv.pytorch.quantization.intrinsic
import
SpconvReLUNd
,
SpconvAddReLUNd
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
...
@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv):
...
@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv):
def
forward
(
self
,
input
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
def
forward
(
self
,
input
,
add_input
:
Optional
[
SparseConvTensor
]
=
None
):
msg
=
f
"
{
input
.
features
.
shape
[
0
]
}
,
{
input
.
features
.
shape
[
1
]
}
,
{
self
.
weight
().
shape
[
0
]
}
"
msg
=
f
"
{
input
.
features
.
shape
[
0
]
}
,
{
input
.
features
.
shape
[
1
]
}
,
{
self
.
weight
().
shape
[
0
]
}
"
with
tv
.
measure_and_print
(
f
"QuantizedSparseConvAddReLU|
{
msg
}
"
,
get_current_stream
(),
enable
=
False
):
with
tv
.
measure_and_print
(
f
"QuantizedSparseConvAddReLU|
{
msg
}
"
,
get_current_stream
(),
enable
=
False
):
inp_scale
=
input
.
q_scale
()
inp_scale
=
input
.
q_scale
()
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
w_scales
=
self
.
weight
().
q_per_channel_scales
().
to
(
torch
.
float32
)
out_scale
=
self
.
scale
out_scale
=
self
.
scale
...
...
spconv/pytorch/quantization/quantized/conv.py
View file @
1f6deed6
...
@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
bias_float
=
(
bias_float
=
(
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
float
,
torch
.
zeros
(
out_channels
,
dtype
=
torch
.
float
,
**
{
k
:
v
for
k
,
v
in
factory_kwargs
.
items
()
if
k
!=
'dtype'
})
if
bias
else
None
)
**
{
k
:
v
for
k
,
v
in
factory_kwargs
.
items
()
if
k
!=
'dtype'
})
if
bias
else
None
)
self
.
_max_voxels
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
set_weight_bias
(
qweight
,
bias_float
)
self
.
set_weight_bias
(
qweight
,
bias_float
)
self
.
scale
=
1.0
self
.
scale
=
1.0
self
.
zero_point
=
0
self
.
zero_point
=
0
...
@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
self
.
_weight
:
torch
.
Tensor
=
qweight
self
.
_weight
:
torch
.
Tensor
=
qweight
self
.
_bias
:
torch
.
Tensor
=
bias_float
self
.
_bias
:
torch
.
Tensor
=
bias_float
def
set_max_voxels
(
self
,
max_voxel
):
self
.
_max_voxels
=
max_voxel
def
bias
(
self
):
def
bias
(
self
):
return
self
.
_bias
return
self
.
_bias
...
@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
destination
[
prefix
+
'bias'
]
=
b
destination
[
prefix
+
'bias'
]
=
b
destination
[
prefix
+
'scale'
]
=
torch
.
tensor
(
self
.
scale
)
destination
[
prefix
+
'scale'
]
=
torch
.
tensor
(
self
.
scale
)
destination
[
prefix
+
'zero_point'
]
=
torch
.
tensor
(
self
.
zero_point
)
destination
[
prefix
+
'zero_point'
]
=
torch
.
tensor
(
self
.
zero_point
)
destination
[
prefix
+
'max_voxels'
]
=
torch
.
tensor
(
self
.
_max_voxels
)
# @torch.jit.export
# @torch.jit.export
# def __getstate__(self):
# def __getstate__(self):
...
@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
state_dict
.
pop
(
prefix
+
'weight'
)
state_dict
.
pop
(
prefix
+
'weight'
)
state_dict
.
pop
(
prefix
+
'bias'
)
state_dict
.
pop
(
prefix
+
'bias'
)
self
.
scale
=
float
(
state_dict
[
prefix
+
'scale'
])
self
.
scale
=
float
(
state_dict
[
prefix
+
'scale'
])
state_dict
.
pop
(
prefix
+
'max_voxels'
)
self
.
_max_voxels
=
state_dict
[
prefix
+
'max_voxels'
]
state_dict
.
pop
(
prefix
+
'scale'
)
state_dict
.
pop
(
prefix
+
'scale'
)
self
.
zero_point
=
int
(
state_dict
[
prefix
+
'zero_point'
])
self
.
zero_point
=
int
(
state_dict
[
prefix
+
'zero_point'
])
state_dict
.
pop
(
prefix
+
'zero_point'
)
state_dict
.
pop
(
prefix
+
'zero_point'
)
...
@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
assert
weight_post_process
.
dtype
==
torch
.
qint8
,
\
assert
weight_post_process
.
dtype
==
torch
.
qint8
,
\
'Weight observer must have a dtype of qint8'
'Weight observer must have a dtype of qint8'
qweight
=
_quantize_weight
(
mod
.
weight
.
float
(),
weight_post_process
)
qweight
=
_quantize_weight
(
mod
.
weight
.
float
(),
weight_post_process
)
# the __init__ call used is the one from derived classes and not the one from _ConvNd
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv
=
cls
(
mod
.
ndim
,
mod
.
in_channels
,
mod
.
out_channels
,
mod
.
kernel_size
,
qconv
=
cls
(
mod
.
ndim
,
mod
.
in_channels
,
mod
.
out_channels
,
mod
.
kernel_size
,
mod
.
stride
,
mod
.
padding
,
mod
.
dilation
,
mod
.
stride
,
mod
.
padding
,
mod
.
dilation
,
...
@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
act_alpha
=
mod
.
act_alpha
,
act_alpha
=
mod
.
act_alpha
,
act_beta
=
mod
.
act_beta
)
act_beta
=
mod
.
act_beta
)
qconv
.
set_weight_bias
(
qweight
,
mod
.
bias
)
qconv
.
set_weight_bias
(
qweight
,
mod
.
bias
)
if
mod
.
get_max_num_voxels
()
is
not
None
:
qconv
.
set_max_voxels
(
mod
.
get_max_num_voxels
())
if
activation_post_process
is
None
or
activation_post_process
.
dtype
==
torch
.
float
:
if
activation_post_process
is
None
or
activation_post_process
.
dtype
==
torch
.
float
:
return
qconv
# dynamic quantization doesn't need scale/zero_point
return
qconv
# dynamic quantization doesn't need scale/zero_point
else
:
else
:
...
@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
...
@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
qconv
.
set_weight_bias
(
qweight
,
ref_qconv
.
bias
)
qconv
.
set_weight_bias
(
qweight
,
ref_qconv
.
bias
)
qconv
.
scale
=
float
(
output_scale
)
qconv
.
scale
=
float
(
output_scale
)
qconv
.
zero_point
=
int
(
output_zero_point
)
qconv
.
zero_point
=
int
(
output_zero_point
)
if
ref_qconv
.
get_max_num_voxels
()
is
not
None
:
qconv
.
set_max_voxels
(
ref_qconv
.
get_max_num_voxels
())
return
qconv
return
qconv
...
...
spconv/pytorch/quantization/quantized/reference.py
View file @
1f6deed6
...
@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule):
...
@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule):
qref_conv
.
weight
=
torch
.
nn
.
Parameter
(
float_conv
.
weight
.
detach
())
qref_conv
.
weight
=
torch
.
nn
.
Parameter
(
float_conv
.
weight
.
detach
())
if
float_conv
.
bias
is
not
None
:
if
float_conv
.
bias
is
not
None
:
qref_conv
.
bias
=
torch
.
nn
.
Parameter
(
float_conv
.
bias
.
detach
())
qref_conv
.
bias
=
torch
.
nn
.
Parameter
(
float_conv
.
bias
.
detach
())
if
conv
.
get_max_num_voxels
()
is
not
None
:
qref_conv
.
get_max_num_voxels
()[:]
=
conv
.
get_max_num_voxels
()
return
qref_conv
return
qref_conv
...
...
test/test_all_algo.py
View file @
1f6deed6
...
@@ -273,6 +273,9 @@ class SparseConvTester:
...
@@ -273,6 +273,9 @@ class SparseConvTester:
if
self
.
check_int8_infer
:
if
self
.
check_int8_infer
:
rescaled
=
output_ref
.
astype
(
self
.
dtype_comp
)
*
self
.
scales
.
astype
(
self
.
dtype_comp
)
rescaled
=
output_ref
.
astype
(
self
.
dtype_comp
)
*
self
.
scales
.
astype
(
self
.
dtype_comp
)
rescaled
+=
self
.
bias
.
astype
(
self
.
dtype_comp
)
rescaled
+=
self
.
bias
.
astype
(
self
.
dtype_comp
)
if
self
.
subm
:
rescaled
+=
self
.
output_add
.
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
else
:
rescaled
+=
self
.
output_add
[
self
.
out_order
].
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
rescaled
+=
self
.
output_add
[
self
.
out_order
].
astype
(
self
.
dtype_comp
)
*
self
.
output_add_scale
if
self
.
check_act
:
if
self
.
check_act
:
rescaled
=
np
.
maximum
(
rescaled
,
0
)
rescaled
=
np
.
maximum
(
rescaled
,
0
)
...
@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool):
...
@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool):
def
test_all_algo_unit
():
def
test_all_algo_unit
():
# for i in range(5):
# for i in range(5):
#
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda
(
True
)
_test_impgemm_conv_cuda
(
False
)
#
_test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
# _test_native_conv_cuda(False)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment