Commit 1f6deed6 authored by yan.yan's avatar yan.yan
Browse files

prepare int8 release

parent 5b3fe9e7
......@@ -317,6 +317,7 @@ class ResidualNetPTQ(nn.Module):
super(ResidualNetPTQ, self).__init__()
self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3),
# SubMConvBNReLU(32, 32, 3),
SparseBasicBlock2(32, 32),
SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2), # 14x14
......@@ -474,55 +475,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else:
output = model(image)
def is_dequantize_node(node):
    """Tell whether ``node`` is an FX node that calls the ``dequantize`` method.

    Anything that is not a ``torch.fx.Node`` yields ``False``.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    is_method_call = node.op == "call_method"
    return is_method_call and node.target == "dequantize"
def _get_module(node: torch.fx.Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
    """Look up the ``torch.nn.Module`` a ``call_module`` node points at.

    Args:
        node: FX node to resolve.
        modules: mapping from qualified module name to module.

    Returns:
        The module registered under ``str(node.target)``, or ``None`` when the
        node is not a ``call_module`` node or its target is not in ``modules``.
    """
    if node.op != "call_module":
        return None
    return modules.get(str(node.target))
def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
    """Let fused conv-add-relu modules consume their residual input undequantized.

    For every ``call_module`` node whose module is exactly
    ``snniq.SparseConvAddReLU``, the residual ("add") input is inspected. If it
    is produced by a ``dequantize`` method call, the node is rewired to read
    that call's input directly. ``dequantize`` nodes left without users are
    then removed by dead-code elimination.

    Args:
        model: graph module to rewrite in place.

    Returns:
        The same ``model`` object after linting and recompiling its graph.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        if n.op != "call_module":
            continue
        # Exact type match on purpose: subclasses are left untouched.
        if type(_get_module(n, modules)) is not snniq.SparseConvAddReLU:
            continue
        # The residual input is optional in SparseConvAddReLU.forward, so it
        # may be missing or passed by keyword; never index args blindly.
        if len(n.args) > 1:
            add_input = n.args[1]
        else:
            add_input = n.kwargs.get("add_input")
        # is_dequantize_node already guarantees add_input is a torch.fx.Node.
        if is_dequantize_node(add_input):
            n.replace_input_with(add_input, add_input.args[0])
    model.graph.eliminate_dead_code()
    # Check the graph is well-formed before regenerating the module's code.
    model.graph.lint()
    model.recompile()
    return model
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Redirect ``torch.quantize_per_tensor`` calls to the custom ``quantize_per_tensor``.

    ``torch.quantize_per_tensor`` doesn't support ``SparseConvTensor``, so an
    FX transform swaps in a custom implementation.

    Args:
        m: graph module to patch in place.

    Returns:
        The same ``m`` object after linting and recompiling its graph.
    """
    quantize_calls = (
        node for node in m.graph.nodes
        # Only ``call_function`` nodes carry a plain callable as ``target``.
        if node.op == 'call_function' and node.target == torch.quantize_per_tensor
    )
    for call in quantize_calls:
        call.target = quantize_per_tensor
    # Check the graph is well-formed before regenerating the module's code.
    m.graph.lint()
    m.recompile()
    return m
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
......@@ -562,7 +514,7 @@ def main():
help='random seed (default: 1)')
parser.add_argument('--sparse',
action='store_true',
default=False,
default=True,
help='use sparse conv network instead of dense')
parser.add_argument(
'--log-interval',
......@@ -589,7 +541,7 @@ def main():
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse:
model = NetV2().to(device)
model = ResidualNetPTQ().to(device)
else:
model = NetDense().to(device)
......
......@@ -380,7 +380,7 @@ class SpconvOps:
"""
...
@staticmethod
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator(data: Tensor, alloc_func, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1, do_sort: bool = True) -> Tensor:
"""
Args:
data:
......@@ -388,10 +388,11 @@ class SpconvOps:
indices:
stream:
mask_count:
do_sort:
"""
...
@staticmethod
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1) -> Tensor:
def sort_1d_by_key_allocator_v2(data: Tensor, allocator, indices: Tensor = Tensor(), stream: int = 0, mask_count: int = 1, do_sort: bool = True) -> Tensor:
"""
Args:
data:
......@@ -399,6 +400,7 @@ class SpconvOps:
indices:
stream:
mask_count:
do_sort:
"""
...
@staticmethod
......@@ -555,7 +557,7 @@ class SpconvOps:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}, do_sort: bool = True) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......@@ -576,6 +578,7 @@ class SpconvOps:
timer:
direct_table:
preallocated:
do_sort:
"""
...
@staticmethod
......
......@@ -922,6 +922,8 @@ class SpconvOps(pccm.Class):
pyanno="cumm.tensorview.Tensor = Tensor()")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.arg("mask_count", "int", "1", pyanno="int")
code.arg("do_sort", "bool", "true")
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator:
......@@ -935,6 +937,9 @@ class SpconvOps(pccm.Class):
}}
tv::cuda::Launch launcher(data.dim(0), stream_cu);
launcher(cudakers::arange_kernel<int32_t>, indices.data_ptr<int32_t>(), indices.dim(0));
if (!do_sort){{
return indices;
}}
// auto timer = tv::CUDATimer();
""")
# nested tv::dispatch may cause compiler bug in msvc.
......@@ -1645,6 +1650,7 @@ class SpconvOps(pccm.Class):
code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>",
"std::unordered_map<std::string, tv::Tensor>{}",
"Dict[str, cumm.tensorview.Tensor] = {}")
code.arg("do_sort", f"bool", "true")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
......@@ -1788,7 +1794,7 @@ class SpconvOps(pccm.Class):
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, num_act_in}}, tv::int32, 0, stream_int);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count);
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int, mask_int_count, do_sort);
}}
""")
with code.else_():
......@@ -1952,6 +1958,7 @@ Your Conv Params: )" << "\\n";
tv::CUDAKernelTimerGuard timer_guard("gen_conv_inds_sort",
timer, reinterpret_cast<cudaStream_t>(stream_int));
if (is_mask_split){{
TV_ASSERT_RT_ERR(do_sort, "not implemented for now");
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{
......@@ -1967,12 +1974,12 @@ Your Conv Params: )" << "\\n";
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count, do_sort);
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int, mask_int_count);
mask_argsort_fwd[0], stream_int, mask_int_count, do_sort);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int, mask_int_count);
mask_argsort_bwd[0], stream_int, mask_int_count, do_sort);
}}
}}
}}
......
......@@ -304,6 +304,7 @@ class StaticAllocator(ExternalAllocator):
code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
code.raw(f"""
auto tvctx = tv::Context();
""")
......@@ -328,6 +329,7 @@ class StaticAllocator(ExternalAllocator):
code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
code.raw(f"""
if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
......
......@@ -2201,7 +2201,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
if (!output_add.empty() && tune_res.algo_desp.is_int8_inference){{
// use source as bias
beta = output_add_scale;
beta = output_add_scale / output_scale;
}}
if (j > 0){{
......
......@@ -34,8 +34,11 @@ def main(include: str,
cu.namespace = "cumm.gemm.main"
all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS)
# all_imp = IMPLGEMM_SIMT_PARAMS
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
# keep all int8 kernels in libspconv
for x in all_imp:
if x.int8_inference:
x.is_nvrtc = False
all_imp = list(filter(lambda x: (not x.is_nvrtc), all_imp))
if inference_only:
all_imp = list(filter(lambda x: x.op_type == ConvOpType.kForward, all_imp))
convcu = ConvMainUnitTest(all_imp)
......
......@@ -136,6 +136,9 @@ class SparseConvolutionBase:
self.zero_point = 0
if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
def is_inverseable(self):
    """Return True when the layer has an ``indice_key`` and is not submanifold."""
    if self.subm:
        return False
    return self.indice_key is not None
def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None,
......@@ -681,6 +684,9 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
s += ', bias=False'
if self.algo is not None:
s += f', algo={self.algo}'
if self.act_type != tv.gemm.Activation.None_:
s += f', act={self.act_type}'
return s.format(**self.__dict__)
def _calculate_fan_in_and_fan_out(self):
......@@ -730,8 +736,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def is_inverseable(self):
    """Return True when the layer has an ``indice_key`` and is not submanifold."""
    has_indice_key = self.indice_key is not None
    return has_indice_key and not self.subm
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(self.training, input, self.weight, self.bias, add_input,
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Optional, Tuple, TypeVar, Union, Dict
import numpy as np
import torch
......@@ -181,6 +181,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo
self.int8_scale: Optional[np.ndarray] = None
def __repr__(self):
    """Return a short description that shows only the shape of ``_features``."""
    feature_shape = self._features.shape
    return f"SparseConvTensor[shape={feature_shape}]"
@property
def is_quantized(self):
......@@ -204,6 +208,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt.thrust_allocator = self.thrust_allocator
new_spt._timer = self._timer
new_spt.force_algo = self.force_algo
new_spt.int8_scale = self.int8_scale
return new_spt
......@@ -302,6 +307,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
tensor.thrust_allocator = self.thrust_allocator
tensor._timer = self._timer
tensor.force_algo = self.force_algo
tensor.int8_scale = self.int8_scale
return tensor
def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...], np.ndarray]) -> List[int]:
......
......@@ -137,10 +137,11 @@ class TorchAllocator(ExternalAllocator):
else:
ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_()
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
# if self.is_quantized:
# ctx = tv.Context()
# ctx.set_cuda_stream(stream)
# ten_tv.zero_(ctx)
if self.is_quantized:
# no _zeros_affine_quantized available, so we need to zero_ here.
ctx = tv.Context()
ctx.set_cuda_stream(stream)
ten_tv.zero_(ctx)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
self.allocated[name] = ten
......
......@@ -1468,6 +1468,7 @@ def implicit_gemm(features: torch.Tensor,
bias_tv = tv.Tensor()
scale_tv = tv.Tensor()
output_add_tv = tv.Tensor()
is_int8 = features.is_quantized and filters.is_quantized
if output_add is not None:
assert features.dtype == torch.qint8, "fused residual add only support int8"
if bias is not None:
......@@ -1535,15 +1536,31 @@ def implicit_gemm(features: torch.Tensor,
filters = filters.reshape(out_channel, -1, filters.shape[-1])
kv = filters.shape[1]
mask_int_count = div_up(kv, 32)
if is_subm:
out_features = torch.empty((num_activate_out, out_channel),
dtype=output_dtype,
device=features.device)
if is_int8:
if is_subm:
out_features = torch._empty_affine_quantized(size=(num_activate_out, out_channel),
scale=output_scale, zero_point=0, dtype=features.dtype, device=features.device)
# out_features = torch.empty((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else:
out_features = torch._empty_affine_quantized(size=(num_activate_out, out_channel),
scale=output_scale, zero_point=0, dtype=features.dtype, device=features.device)
ctx = tv.Context()
ctx.set_cuda_stream(stream)
torch_tensor_to_tv(out_features).zero_(ctx)
# out_features = torch.zeros((num_activate_out, out_channel),
# dtype=output_dtype,
# device=features.device)
else:
out_features = torch.zeros((num_activate_out, out_channel),
dtype=output_dtype,
device=features.device)
if is_subm:
out_features = torch.empty((num_activate_out, out_channel),
dtype=output_dtype,
device=features.device)
else:
out_features = torch.zeros((num_activate_out, out_channel),
dtype=output_dtype,
device=features.device)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
features_tv = torch_tensor_to_tv(features)
filters_tv = torch_tensor_to_tv(filters)
......@@ -1617,7 +1634,7 @@ def implicit_gemm(features: torch.Tensor,
if bias is not None and not tune_res.algo_desp.is_int8_inference:
beta = 1
if output_add is not None and tune_res.algo_desp.is_int8_inference:
beta = output_add_scale
beta = output_add_scale / output_scale
CONV.run_with_tuned_result(
tune_res,
ConvOpType.kForward,
......@@ -1640,7 +1657,7 @@ def implicit_gemm(features: torch.Tensor,
act_alpha=act_alpha,
act_beta=act_beta,
scale=scale_tv,
output_add=output_add)
output_add=output_add_tv)
return out_features, mask_output_fwd, mask_width
......
......@@ -591,14 +591,14 @@ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU),
# use simple cumm i8 conv to implement linear
nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
# nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
}
SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module],
Type[WeightedQuantizedModule]] = {
snnqr.SpConv: snnq.SparseConv,
nnqr.Linear: snnq.LinearPerChannelWeight,
# nnqr.Linear: snnq.LinearPerChannelWeight,
}
......
......@@ -14,7 +14,7 @@
from typing import Optional
from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.cppcore import get_current_stream
from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
import spconv.pytorch.quantization.quantized as nnq
from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd
from cumm import tensorview as tv
......@@ -88,7 +88,6 @@ class SparseConvAddReLU(nnq.SparseConv):
def forward(self, input, add_input: Optional[SparseConvTensor] = None):
msg = f"{input.features.shape[0]}, {input.features.shape[1]}, {self.weight().shape[0]}"
with tv.measure_and_print(f"QuantizedSparseConvAddReLU|{msg}", get_current_stream(), enable=False):
inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale
......
......@@ -87,7 +87,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
bias_float = (
torch.zeros(out_channels, dtype=torch.float,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
self._max_voxels = torch.zeros(1, dtype=torch.int32, device=device)
self.set_weight_bias(qweight, bias_float)
self.scale = 1.0
self.zero_point = 0
......@@ -96,6 +96,9 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
self._weight: torch.Tensor = qweight
self._bias: torch.Tensor = bias_float
def set_max_voxels(self, max_voxel):
    """Store ``max_voxel`` as this module's ``_max_voxels`` value.

    The value is kept as given; no copy or conversion is performed.
    """
    self._max_voxels = max_voxel
def bias(self):
    """Return the object stored in ``_bias`` (no copy is made)."""
    stored_bias = self._bias
    return stored_bias
......@@ -137,6 +140,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
destination[prefix + 'bias'] = b
destination[prefix + 'scale'] = torch.tensor(self.scale)
destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
destination[prefix + 'max_voxels'] = torch.tensor(self._max_voxels)
# @torch.jit.export
# def __getstate__(self):
......@@ -169,6 +173,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
state_dict.pop(prefix + 'weight')
state_dict.pop(prefix + 'bias')
self.scale = float(state_dict[prefix + 'scale'])
state_dict.pop(prefix + 'max_voxels')
self._max_voxels = state_dict[prefix + 'max_voxels']
state_dict.pop(prefix + 'scale')
self.zero_point = int(state_dict[prefix + 'zero_point'])
state_dict.pop(prefix + 'zero_point')
......@@ -213,6 +219,7 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
assert weight_post_process.dtype == torch.qint8, \
'Weight observer must have a dtype of qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv = cls(mod.ndim, mod.in_channels, mod.out_channels, mod.kernel_size,
mod.stride, mod.padding, mod.dilation,
......@@ -230,6 +237,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
act_alpha=mod.act_alpha,
act_beta=mod.act_beta)
qconv.set_weight_bias(qweight, mod.bias)
if mod.get_max_num_voxels() is not None:
qconv.set_max_voxels(mod.get_max_num_voxels())
if activation_post_process is None or activation_post_process.dtype == torch.float:
return qconv # dynamic quantization doesn't need scale/zero_point
else:
......@@ -295,6 +304,8 @@ class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
qconv.set_weight_bias(qweight, ref_qconv.bias)
qconv.scale = float(output_scale)
qconv.zero_point = int(output_zero_point)
if ref_qconv.get_max_num_voxels() is not None:
qconv.set_max_voxels(ref_qconv.get_max_num_voxels())
return qconv
......
......@@ -85,6 +85,8 @@ class _SpConvNd(sconvmod.SparseConvolution, ReferenceQuantizedModule):
qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
if float_conv.bias is not None:
qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
if conv.get_max_num_voxels() is not None:
qref_conv.get_max_num_voxels()[:] = conv.get_max_num_voxels()
return qref_conv
......
......@@ -273,7 +273,10 @@ class SparseConvTester:
if self.check_int8_infer:
rescaled = output_ref.astype(self.dtype_comp) * self.scales.astype(self.dtype_comp)
rescaled += self.bias.astype(self.dtype_comp)
rescaled += self.output_add[self.out_order].astype(self.dtype_comp) * self.output_add_scale
if self.subm:
rescaled += self.output_add.astype(self.dtype_comp) * self.output_add_scale
else:
rescaled += self.output_add[self.out_order].astype(self.dtype_comp) * self.output_add_scale
if self.check_act:
rescaled = np.maximum(rescaled, 0)
if self.out_dtype == np.int8:
......@@ -1020,8 +1023,8 @@ def _test_native_conv_cuda(subm: bool):
def test_all_algo_unit():
# for i in range(5):
# _test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False)
_test_impgemm_conv_cuda(True)
# _test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment