Commit 1f6deed6 authored by yan.yan's avatar yan.yan
Browse files

prepare int8 release

parent 5b3fe9e7
...@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module): ...@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module):
super(ResidualNetPTQ, self).__init__() super(ResidualNetPTQ, self).__init__()
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3), SubMConvBNReLU(1, 32, 3),
# SubMConvBNReLU(32, 32, 3),
SparseBasicBlock2(32, 32), SparseBasicBlock2(32, 32),
SubMConvBNReLU(32, 64, 3), SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2), # 14x14 SparseConvBNReLU(64, 64, 2, 2), # 14x14
...@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device): ...@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else: else:
output = model(image) output = model(image)
def is_dequantize_node(node):
return isinstance(node, torch.fx.Node) and node.op == "call_method" and node.target == "dequantize"
def _get_module(node: torch.fx.Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if node.op == "call_module" and str(node.target) in modules:
return modules[str(node.target)]
else:
return None
def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
modules = dict(model.named_modules(remove_duplicate=False))
for n in model.graph.nodes:
if (n.op == "call_module" and type(_get_module(n, modules)) == snniq.SparseConvAddReLU):
# check second input, if it's dequantized, remove that dequantize node
arg1 = n.args[1]
if is_dequantize_node(arg1):
dq_node = arg1
assert(isinstance(dq_node, torch.fx.Node))
dn_input = dq_node.args[0]
n.replace_input_with(dq_node, dn_input)
model.graph.eliminate_dead_code()
model.recompile()
model.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return model
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for node in m.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.quantize_per_tensor:
node.target = quantize_per_tensor
m.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
m.recompile()
return m
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
...@@ -562,7 +514,7 @@ def main(): ...@@ -562,7 +514,7 @@ def main():
help='random seed (default: 1)') help='random seed (default: 1)')
parser.add_argument('--sparse', parser.add_argument('--sparse',
action='store_true', action='store_true',
default=False, default=True,
help='use sparse conv network instead of dense') help='use sparse conv network instead of dense')
parser.add_argument( parser.add_argument(
'--log-interval', '--log-interval',
...@@ -589,7 +541,7 @@ def main(): ...@@ -589,7 +541,7 @@ def main():
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu") qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse: if args.sparse:
model = NetV2().to(device) model = ResidualNetPTQ().to(device)
else: else:
model = NetDense().to(device) model = NetDense().to(device)
......
...@@ -380,7 +380,7 @@ class SpconvOps: ...@@ -380,7 +380,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor: def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1, do_sort: bool = True) -> Tensor:
""" """
Args: Args:
data: data:
...@@ -388,10 +388,11 @@ class SpconvOps: ...@@ -388,10 +388,11 @@ class SpconvOps:
indices: indices:
stream: stream:
mask_count: mask_count:
do_sort:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor: def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1, do_sort: bool = True) -> Tensor:
""" """
Args: Args:
data: data:
...@@ -399,6 +400,7 @@ class SpconvOps: ...@@ -399,6 +400,7 @@ class SpconvOps:
indices: indices:
stream: stream:
mask_count: mask_count:
do_sort:
""" """
... ...
@staticmethod @staticmethod
...@@ -555,7 +557,7 @@ class SpconvOps: ...@@ -555,7 +557,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}, do_sort: bool = True) -> Tuple[Tensor, int]:
""" """
Args: Args:
allocator: allocator:
...@@ -576,6 +578,7 @@ class SpconvOps: ...@@ -576,6 +578,7 @@ class SpconvOps:
timer: timer:
direct_table: direct_table:
preallocated: preallocated:
do_sort:
""" """
... ...
@staticmethod @staticmethod
......
...@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class): ...@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class):
pyanno="cumm.tensorview.Tensor = Tensor()") pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_count", "int", "1", pyanno="int") code.arg("mask_count", "int", "1", pyanno="int")
code.arg("do_sort", "bool", "true")
code.add_dependency(CustomThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel) code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator: if not use_allocator:
...@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class): ...@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class):
}} }}
tv::cuda::Launch launcher(data.dim(0), stream_cu); tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0)); launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
if (!do_sort){{
return indices;
}}
// auto timer = tv::CUDATimer(); // auto timer = tv::CUDATimer();
""") """)
# nested tv::dispatch may cause compiler bug in msvc. # nested tv::dispatch may cause compiler bug in msvc.
...@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class): ...@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class):
code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>", code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>",
"std::unordered_map<std::string, tv::Tensor>{}", "std::unordered_map<std::string, tv::Tensor>{}",
"Dict[str, cumm.tensorview.Tensor] = {}") "Dict[str, cumm.tensorview.Tensor] = {}")
code.arg("do_sort", f"bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f""" code.raw(f"""
...@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class): ...@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class):
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)}, auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int); {{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count); sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count, do_sort);
}} }}
""") """)
with code.else_(): with code.else_():
...@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n"; ...@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n";
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort", tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int)); timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{ if (is_mask_split){{
TV_ASSERT_RT_ERR(do_sort, "not implemented for now");
for (int j = 0; j < mask_split_count; ++j){{ for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1); auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{ if (!is_train){{
...@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n"; ...@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n";
}}else{{ }}else{{
if (!is_train){{ if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count, do_sort);
}}else{{ }}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count); mask_argsort_fwd[0], stream_int, mask_int_count, do_sort);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc, sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count); mask_argsort_bwd[0], stream_int, mask_int_count, do_sort);
}} }}
}} }}
}} }}
......
...@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator): ...@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator):
code.arg("device", "int") code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false") code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
code.raw(f""" code.raw(f"""
auto tvctx = tv::Context(); auto tvctx = tv::Context();
""") """)
...@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator): ...@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator):
code.arg("device", "int") code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false") code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
code.raw(f""" code.raw(f"""
if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{ if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate // thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
......
...@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}} }}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{ if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias // use source as bias
beta = output_add_scale; beta = output_add_scale / output_scale;
}} }}
if (j > 0){{ if (j > 0){{
......
...@@ -34,8 +34,11 @@ def main(include: str, ...@@ -34,8 +34,11 @@ def main(include: str,
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS) IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS # keep all int8 kernels in libspconv
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) for x in all_imp:
if x.int8_inference:
x.is_nvrtc = False
all_imp = list(filter(lambda x: (not x.is_nvrtc), all_imp))
if inference_only: if inference_only:
all_imp = list(filter(lambda x: x.op_type == ConvOpType.kForward, all_imp)) all_imp = list(filter(lambda x: x.op_type == ConvOpType.kForward, all_imp))
convcu = ConvMainUnitTest(all_imp) convcu = ConvMainUnitTest(all_imp)
......
...@@ -137,6 +137,9 @@ class SparseConvolutionBase: ...@@ -137,6 +137,9 @@ class SparseConvolutionBase:
if self.conv1x1: if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act" assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None, def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None, channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None,
sparse_unique_name: str = "", sparse_unique_name: str = "",
...@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
s += ', bias=False' s += ', bias=False'
if self.algo is not None: if self.algo is not None:
s += f', algo={self.algo}' s += f', algo={self.algo}'
if self.act_type != tv.gemm.Activation.None_:
s += f', act={self.act_type}'
return s.format(**self.__dict__) return s.format(**self.__dict__)
def _calculate_fan_in_and_fan_out(self): def _calculate_fan_in_and_fan_out(self):
...@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None): def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(self.training, input, self.weight, self.bias, add_input, return self._conv_forward(self.training, input, self.weight, self.bias, add_input,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict from typing import Any, List, Optional, Tuple, TypeVar, Union, Dict
import numpy as np import numpy as np
import torch import torch
...@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer) self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo self.force_algo = force_algo
self.int8_scale: Optional[np.ndarray] = None
def __repr__(self):
return f"SparseConvTensor[shape={self._features.shape}]"
@property @property
def is_quantized(self): def is_quantized(self):
...@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt.thrust_allocator = self.thrust_allocator new_spt.thrust_allocator = self.thrust_allocator
new_spt._timer = self._timer new_spt._timer = self._timer
new_spt.force_algo = self.force_algo new_spt.force_algo = self.force_algo
new_spt.int8_scale = self.int8_scale
return new_spt return new_spt
...@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
tensor.thrust_allocator = self.thrust_allocator tensor.thrust_allocator = self.thrust_allocator
tensor._timer = self._timer tensor._timer = self._timer
tensor.force_algo = self.force_algo tensor.force_algo = self.force_algo
tensor.int8_scale = self.int8_scale
return tensor return tensor
def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...], np.ndarray]) -> List[int]: def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...], np.ndarray]) -> List[int]:
......
...@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator): ...@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator):
else: else:
ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_() ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_()
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
# if self.is_quantized: if self.is_quantized:
# ctx = tv.Context() # no _zeros_affine_quantized available, so we need to zero_ here.
# ctx.set_cuda_stream(stream) ctx = tv.Context()
# ten_tv.zero_(ctx) ctx.set_cuda_stream(stream)
ten_tv.zero_(ctx)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
......
...@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor,
bias_tv = tv.Tensor() bias_tv = tv.Tensor()
scale_tv = tv.Tensor() scale_tv = tv.Tensor()
output_add_tv = tv.Tensor() output_add_tv = tv.Tensor()
is_int8 = features.is_quantized and filters.is_quantized
if output_add is not None: if output_add is not None:
assert features.dtype == torch.qint8, "fused residual add only support int8" assert features.dtype == torch.qint8, "fused residual add only support int8"
if bias is not None: if bias is not None:
...@@ -1535,6 +1536,23 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1535,6 +1536,23 @@ def implicit_gemm(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1]) filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1] kv = filters.shape[1]
mask_int_count = div_up(kv, 32) mask_int_count = div_up(kv, 32)
if is_int8:
if is_subm:
out_features = torch._empty_affine_quantized(size=(num_activate_out, out_channel),
scale=output_scale, zero_point=0, dtype=features.dtype, device=features.device)
# out_features = torch.empty((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else:
out_features = torch._empty_affine_quantized(size=(num_activate_out, out_channel),
scale=output_scale, zero_point=0, dtype=features.dtype, device=features.device)
ctx = tv.Context()
ctx.set_cuda_stream(stream)
torch_tensor_to_tv(out_features).zero_(ctx)
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else:
if is_subm: if is_subm:
out_features = torch.empty((num_activate_out, out_channel), out_features = torch.empty((num_activate_out, out_channel),
dtype=output_dtype, dtype=output_dtype,
...@@ -1543,7 +1561,6 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1543,7 +1561,6 @@ def implicit_gemm(features: torch.Tensor,
out_features = torch.zeros((num_activate_out, out_channel), out_features = torch.zeros((num_activate_out, out_channel),
dtype=output_dtype, dtype=output_dtype,
device=features.device) device=features.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
...@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor,
if bias is not None and not tune_res.algo_desp.is_int8_inference: if bias is not None and not tune_res.algo_desp.is_int8_inference:
beta = 1 beta = 1
if output_add is not None and tune_res.algo_desp.is_int8_inference: if output_add is not None and tune_res.algo_desp.is_int8_inference:
beta = output_add_scale beta = output_add_scale / output_scale
CONV.run_with_tuned_result( CONV.run_with_tuned_result(
tune_res, tune_res,
ConvOpType.kForward, ConvOpType.kForward,
...@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor,
act_alpha=act_alpha, act_alpha=act_alpha,
act_beta=act_beta, act_beta=act_beta,
scale=scale_tv, scale=scale_tv,
output_add=output_add) output_add=output_add_tv)
return out_features, mask_output_fwd, mask_width return out_features, mask_output_fwd, mask_width
......
...@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[ ...@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU), snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU), snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU),
# use simple cumm i8 conv to implement linear # use simple cumm i8 conv to implement linear
nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU), # nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
} }
SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module],
Type[WeightedQuantizedModule]] = { Type[WeightedQuantizedModule]] = {
snnqr.SpConv: snnq.SparseConv, snnqr.SpConv: snnq.SparseConv,
nnqr.Linear: snnq.LinearPerChannelWeight, # nnqr.Linear: snnq.LinearPerChannelWeight,
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import Optional from typing import Optional
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.cppcore import get_current_stream from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
import spconv.pytorch.quantization.quantized as nnq import spconv.pytorch.quantization.quantized as nnq
from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd
from cumm import tensorview as tv from cumm import tensorview as tv
...@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv): ...@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv):
def forward(self, input, add_input: Optional[SparseConvTensor] = None): def forward(self, input, add_input: Optional[SparseConvTensor] = None):
msg = f"{input.features.shape[0]}, {input.features.shape[1]}, {self.weight().shape[0]}" msg = f"{input.features.shape[0]}, {input.features.shape[1]}, {self.weight().shape[0]}"
with tv.measure_and_print(f"QuantizedSparseConvAddReLU|{msg}", get_current_stream(), enable=False): with tv.measure_and_print(f"QuantizedSparseConvAddReLU|{msg}", get_current_stream(), enable=False):
inp_scale = input.q_scale() inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales().to(torch.float32) w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale out_scale = self.scale
......
...@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
bias_float = ( bias_float = (
torch.zeros(out_channels, dtype=torch.float, torch.zeros(out_channels, dtype=torch.float,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None) **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
self._max_voxels = torch.zeros(1, dtype=torch.int32, device=device)
self.set_weight_bias(qweight, bias_float) self.set_weight_bias(qweight, bias_float)
self.scale = 1.0 self.scale = 1.0
self.zero_point = 0 self.zero_point = 0
...@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
self._weight: torch.Tensor = qweight self._weight: torch.Tensor = qweight
self._bias: torch.Tensor = bias_float self._bias: torch.Tensor = bias_float
def set_max_voxels(self, max_voxel):
self._max_voxels = max_voxel
def bias(self): def bias(self):
return self._bias return self._bias
...@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
destination[prefix + 'bias'] = b destination[prefix + 'bias'] = b
destination[prefix + 'scale'] = torch.tensor(self.scale) destination[prefix + 'scale'] = torch.tensor(self.scale)
destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
destination[prefix + 'max_voxels'] = torch.tensor(self._max_voxels)
# @torch.jit.export # @torch.jit.export
# def __getstate__(self): # def __getstate__(self):
...@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
state_dict.pop(prefix + 'weight') state_dict.pop(prefix + 'weight')
state_dict.pop(prefix + 'bias') state_dict.pop(prefix + 'bias')
self.scale = float(state_dict[prefix + 'scale']) self.scale = float(state_dict[prefix + 'scale'])
state_dict.pop(prefix + 'max_voxels')
self._max_voxels = state_dict[prefix + 'max_voxels']
state_dict.pop(prefix + 'scale') state_dict.pop(prefix + 'scale')
self.zero_point = int(state_dict[prefix + 'zero_point']) self.zero_point = int(state_dict[prefix + 'zero_point'])
state_dict.pop(prefix + 'zero_point') state_dict.pop(prefix + 'zero_point')
...@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
assert weight_post_process.dtype == torch.qint8, \ assert weight_post_process.dtype == torch.qint8, \
'Weight observer must have a dtype of qint8' 'Weight observer must have a dtype of qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process) qweight = _quantize_weight(mod.weight.float(), weight_post_process)
# the __init__ call used is the one from derived classes and not the one from _ConvNd # the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv = cls(mod.ndim, mod.in_channels, mod.out_channels, mod.kernel_size, qconv = cls(mod.ndim, mod.in_channels, mod.out_channels, mod.kernel_size,
mod.stride, mod.padding, mod.dilation, mod.stride, mod.padding, mod.dilation,
...@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
act_alpha=mod.act_alpha, act_alpha=mod.act_alpha,
act_beta=mod.act_beta) act_beta=mod.act_beta)
qconv.set_weight_bias(qweight, mod.bias) qconv.set_weight_bias(qweight, mod.bias)
if mod.get_max_num_voxels() is not None:
qconv.set_max_voxels(mod.get_max_num_voxels())
if activation_post_process is None or activation_post_process.dtype == torch.float: if activation_post_process is None or activation_post_process.dtype == torch.float:
return qconv # dynamic quantization doesn't need scale/zero_point return qconv # dynamic quantization doesn't need scale/zero_point
else: else:
...@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): ...@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
qconv.set_weight_bias(qweight, ref_qconv.bias) qconv.set_weight_bias(qweight, ref_qconv.bias)
qconv.scale = float(output_scale) qconv.scale = float(output_scale)
qconv.zero_point = int(output_zero_point) qconv.zero_point = int(output_zero_point)
if ref_qconv.get_max_num_voxels() is not None:
qconv.set_max_voxels(ref_qconv.get_max_num_voxels())
return qconv return qconv
......
...@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule): ...@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule):
qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
if float_conv.bias is not None: if float_conv.bias is not None:
qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
if conv.get_max_num_voxels() is not None:
qref_conv.get_max_num_voxels()[:] = conv.get_max_num_voxels()
return qref_conv return qref_conv
......
...@@ -273,6 +273,9 @@ class SparseConvTester: ...@@ -273,6 +273,9 @@ class SparseConvTester:
if self.check_int8_infer: if self.check_int8_infer:
rescaled = output_ref.astype(self.dtype_comp) * self.scales.astype(self.dtype_comp) rescaled = output_ref.astype(self.dtype_comp) * self.scales.astype(self.dtype_comp)
rescaled += self.bias.astype(self.dtype_comp) rescaled += self.bias.astype(self.dtype_comp)
if self.subm:
rescaled += self.output_add.astype(self.dtype_comp) * self.output_add_scale
else:
rescaled += self.output_add[self.out_order].astype(self.dtype_comp) * self.output_add_scale rescaled += self.output_add[self.out_order].astype(self.dtype_comp) * self.output_add_scale
if self.check_act: if self.check_act:
rescaled = np.maximum(rescaled, 0) rescaled = np.maximum(rescaled, 0)
...@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool): ...@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool):
def test_all_algo_unit(): def test_all_algo_unit():
# for i in range(5): # for i in range(5):
# _test_impgemm_conv_cuda(True) _test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False) # _test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True) # _test_native_conv_cuda(True)
# _test_native_conv_cuda(False) # _test_native_conv_cuda(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment