Commit e387ee74 authored by yan.yan's avatar yan.yan
Browse files

sync quantization code

parent b1c57a31
......@@ -13,25 +13,37 @@
# limitations under the License.
from __future__ import print_function
import argparse
import contextlib
import copy
from typing import Dict, Optional
import torch
import spconv.pytorch as spconv
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.ao.quantization import (DeQuantStub, QuantStub,
get_default_qconfig_mapping)
from torch.ao.quantization.fx._lower_to_native_backend import \
STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
import torch.ao.quantization
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.ao.quantization.quantize_fx as qfx
from spconv.pytorch.quantization.fake_q import get_default_spconv_qconfig_mapping
from torchvision import datasets, transforms
import spconv.pytorch as spconv
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
from torch.ao.quantization import get_default_qconfig_mapping
from spconv.pytorch.quantization.backend_cfg import SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from torch.ao.quantization.fx._lower_to_native_backend import STATIC_LOWER_FUSED_MODULE_MAP
from spconv.pytorch.quantization.backend_cfg import \
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP, SPCONV_STATIC_LOWER_MODULE_MAP
from spconv.pytorch.quantization.core import quantize_per_tensor
from spconv.pytorch.quantization.fake_q import \
get_default_spconv_qconfig_mapping
from spconv.pytorch.quantization.intrinsic.modules import SpconvBnAddReLUNd, SpconvAddReLUNd
import spconv.pytorch.quantization.intrinsic.quantized as snniq
@contextlib.contextmanager
def identity_ctx():
......@@ -57,6 +69,142 @@ class SparseConvBNReLU(spconv.SparseSequential):
nn.ReLU(inplace=False)
)
class SparseBasicBlock(spconv.SparseModule):
    """Residual block written in a form supported by spconv quantization.

    Two 3x3 submanifold convolutions, each followed by batch norm, with a
    ReLU after the first and a residual add + ReLU after the second. The
    residual add is done on raw feature tensors.

    Args:
        in_planes: input channel count of the first convolution.
        out_planes: output channel count of both convolutions.
        stride: stride argument forwarded to both convolutions.
        downsample: optional module applied to the block input to produce
            the residual branch; ``None`` uses the input features directly.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Identity on the residual branch; per its name it exists so the fx
        # pattern matcher has a node to anchor on.
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = self.iden_for_fx_match(x.features)
        out = self.conv1_bn_relu(x)
        out = self.conv2_bn(out)
        if self.downsample is not None:
            identity = self.downsample(x)
            # The add below works on raw feature tensors, so unwrap a sparse
            # tensor result instead of adding Tensor + SparseConvTensor.
            # NOTE(review): under torch.fx symbolic tracing this value is a
            # proxy, so the unwrap is presumably skipped there - confirm if a
            # downsample branch is ever used in a traced model.
            if isinstance(identity, spconv.SparseConvTensor):
                identity = identity.features
        out = out.replace_feature(self.relu(out.features + identity))
        return out
class SparseBasicBlock1(spconv.SparseModule):
    """Residual block variant that keeps conv, norm and ReLU as separate
    attributes and applies norm/ReLU on the raw feature tensors.

    ``downsample`` is stored but not used by ``forward``.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        self.conv1 = spconv.SubMConv2d(
            in_planes, out_planes, 3, stride, 1, bias=False)
        self.conv2 = spconv.SubMConv2d(
            out_planes, out_planes, 3, stride, 1, bias=False)
        self.norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        # Residual branch: the input features passed through an identity.
        shortcut = self.iden_for_fx_match(x.features)

        # First stage: conv -> norm -> relu.
        stage = self.conv1(x)
        feats = self.norm1(stage.features)
        feats = self.relu1(feats)
        stage = stage.replace_feature(feats)

        # Second stage: conv -> norm.
        stage = self.conv2(stage)
        feats = self.norm2(stage.features)
        stage = stage.replace_feature(feats)

        # Residual add followed by the final relu.
        feats = self.relu2(stage.features + shortcut)
        return stage.replace_feature(feats)
class SparseBasicBlock2(spconv.SparseModule):
    """Residual block variant built only from sparse-aware modules.

    Norm, ReLU and identity are the ``spconv.Sparse*`` wrappers, so every
    step of ``forward`` takes and returns a sparse tensor and the residual
    add uses the ``+`` operator on the two tensors.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        self.conv1 = spconv.SubMConv2d(
            in_planes, out_planes, 3, stride, 1, bias=False)
        self.conv2 = spconv.SubMConv2d(
            out_planes, out_planes, 3, stride, 1, bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.relu2 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        shortcut = self.iden_for_fx_match(x)

        # Main branch: conv -> norm -> relu -> conv -> norm.
        main = self.conv1(x)
        main = self.norm1(main)
        main = self.relu1(main)
        main = self.conv2(main)
        main = self.norm2(main)

        # An optional downsample module replaces the identity shortcut.
        if self.downsample is not None:
            shortcut = self.downsample(x)

        merged = main + shortcut
        return self.relu2(merged)
class SparseBasicBlock3(spconv.SparseModule):
    """Residual block variant whose second conv, residual add and ReLU are
    wrapped in a single ``SpconvAddReLUNd`` module.

    Args:
        in_planes: input channel count of the first convolution.
        out_planes: output channel count of both convolutions.
        stride: stride argument forwarded to both convolutions.
        downsample: optional module applied to the block input to produce
            the residual branch; ``None`` uses the input itself.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        # No batch norm follows conv2: SpconvAddReLUNd takes only (conv, relu).
        # NOTE(review): confirm that leaving conv2 un-normalized is intended.
        self.residual_conv = SpconvAddReLUNd(conv2, spconv.SparseReLU(inplace=True))
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = self.iden_for_fx_match(x)
        out = self.conv1(x)
        out = self.relu1(self.norm1(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        # conv2 + residual add + relu happen inside the fused-style module.
        out = self.residual_conv(out, identity)
        return out
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
......@@ -126,7 +274,7 @@ class NetV2(nn.Module):
class NetPTQ(nn.Module):
"""pytorch currently don't support cuda int8 inference, so
we only use sparse ops here.
we build a pure sparse network here.
"""
def __init__(self):
super(NetPTQ, self).__init__()
......@@ -138,7 +286,6 @@ class NetPTQ(nn.Module):
SparseConvBNReLU(64, 64, 3, 2, 1), # 4x4
spconv.SparseConv2d(64, 10, 4, 4),
spconv.ToDense(),
)
# self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10)
......@@ -158,22 +305,47 @@ class NetPTQ(nn.Module):
# print(x_sp.shape)
x = x_sp
x = torch.flatten(x, 1)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
class ResidualNetPTQ(nn.Module):
    """Pure sparse network containing one residual block.

    PyTorch currently doesn't support CUDA int8 inference, so the whole
    network is built from sparse ops.
    """

    def __init__(self):
        super().__init__()
        layers = [
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock2(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*layers)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Run the network on sparse input and return log-softmax scores.

        ``features`` and ``indices`` are wrapped into a SparseConvTensor with
        a fixed 28x28 spatial shape after ``features`` passes the quant stub.
        """
        quantized = self.quant(features)
        # Build the SparseConvTensor by hand instead of going via from_dense.
        sparse_input = spconv.SparseConvTensor(
            quantized, indices, [28, 28], batch_size)
        dense_output = self.net(sparse_input)
        flat = torch.flatten(dense_output, 1)
        logits = self.dequant(flat)
        return F.log_softmax(logits, dim=1)
class NetDense(nn.Module):
def __init__(self):
......@@ -184,6 +356,8 @@ class NetDense(nn.Module):
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.iden = spconv.SparseIdentity()
self.quant = QuantStub()
self.dequant = DeQuantStub()
......@@ -195,6 +369,7 @@ class NetDense(nn.Module):
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = self.iden(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
......@@ -299,6 +474,54 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else:
output = model(image)
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Retarget every ``torch.quantize_per_tensor`` call in the graph.

    torch.quantize_per_tensor doesn't support SparseConvTensor, so each
    such ``call_function`` node is pointed at the custom
    ``quantize_per_tensor`` instead. The graph module is modified in place,
    linted, recompiled and returned.
    """
    graph = m.graph
    for node in graph.nodes:
        calls_native_quantize = (
            node.op == 'call_function'
            and node.target == torch.quantize_per_tensor
        )
        if calls_native_quantize:
            node.target = quantize_per_tensor
    # Check the graph is still well-formed before regenerating forward().
    graph.lint()
    m.recompile()
    return m
def is_dequantize_node(node):
    """Return True when ``node`` is an fx node calling the ``dequantize`` method."""
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "call_method" and node.target == "dequantize"
def _get_module(node: torch.fx.Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if node.op == "call_module" and str(node.target) in modules:
return modules[str(node.target)]
else:
return None
def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
    """Drop the dequantize node feeding the residual input of fused conv-add-relu.

    For every ``call_module`` node whose module is exactly
    ``snniq.SparseConvAddReLU``, the second positional argument is checked;
    if it is a ``dequantize`` call, the node is rewired to consume that
    call's input directly. Dead nodes are then eliminated and the graph is
    linted and recompiled.

    Args:
        model: graph module to rewrite in place.

    Returns:
        The same ``model`` object.
    """
    modules = dict(model.named_modules(remove_duplicate=False))
    for n in model.graph.nodes:
        if n.op != "call_module":
            continue
        if type(_get_module(n, modules)) is not snniq.SparseConvAddReLU:
            continue
        # Nothing to rewrite when the residual is not the second positional arg.
        if len(n.args) < 2:
            continue
        residual = n.args[1]
        if is_dequantize_node(residual):
            n.replace_input_with(residual, residual.args[0])
    model.graph.eliminate_dead_code()
    # Validate the graph before generating code from it.
    model.graph.lint()
    model.recompile()
    return model
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
......@@ -361,11 +584,11 @@ def main():
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
device = torch.device("cuda" if use_cuda and args.sparse else "cpu")
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse:
model = NetPTQ().to(device)
model = ResidualNetPTQ().to(device)
else:
model = NetDense().to(device)
......@@ -401,42 +624,61 @@ def main():
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()
# if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt")
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
model.eval()
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
if not args.sparse:
model = model.cpu()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
model_qat = copy.deepcopy(model)
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
STATIC_LOWER_MODULE_MAP.update(SPCONV_STATIC_LOWER_MODULE_MAP)
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
qconfig_mapping = get_default_spconv_qconfig_mapping(False)
prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config()
convert_cfg = spconvq.get_spconv_convert_custom_config()
# convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# prepared_model.print_readable()
print([type(m) for m in prepared_model.modules()])
print(prepared_model)
# print(prepared_model)
# breakpoint()
# print(prepared_model)
# calibrate: run model with some inputs
calibrate(args, prepared_model, test_loader, qdevice)
# calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
converted_model = transform_qdq(converted_model)
# test converted ptq model with int8 kernel
remove_conv_add_dq(converted_model)
converted_model = qfx.convert_to_reference_fx(prepared_model, convert_cfg, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
print([type(m) for m in converted_model.modules()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print(converted_model)
breakpoint()
test(args, converted_model, qdevice, test_loader)
# do qat
# qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
# prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# # converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# # breakpoint()
# print(prepared_model_qat)
# train(args, prepared_model_qat, qdevice, train_loader, optimizer, 1)
# converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# converted_model = transform_qdq(converted_model)
# test(args, converted_model, qdevice, test_loader)
# # [type(m) for m in prepared_model_qat.modules()]
# # model.qconfig = get_default_spconv_trt_ptq_qconfig()
# # prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# # convert_custom_config_dict = spconvq.get_convert_custom_config()
# # torch.ao.quantization.prepare(model, inplace=True)
# # print('Post Training Quantization Prepare: Inserting Observers')
# # print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# # test(args, model, device, test_loader)
# print(converted_model)
# you will see some nvrtc compile log here, which means int8 kernel is used.
breakpoint()
if __name__ == '__main__':
main()
......@@ -188,10 +188,16 @@ class ConvTunerSimple(ConvTunerSimpleBase):
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
# mod = CummNVRTCModule([kernel],
# cudadevrt_path=cudadevrt,
# verbose=True,
# custom_names=custom_names,
# verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
......
......@@ -18,10 +18,10 @@ from typing import List
import pccm
from pccm.utils import project_is_editable, project_is_installed
from ccimport.compat import InWindows
from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT, SPCONV_INT8_DEBUG
if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT and False:
PACKAGE_NAME) and not DISABLE_JIT and not SPCONV_INT8_DEBUG:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
......
......@@ -116,3 +116,5 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32 = False
SPCONV_INT8_DEBUG = False
\ No newline at end of file
This diff is collapsed.
......@@ -14,7 +14,8 @@ from spconv.pytorch.conv import (SparseConv1d, SparseConv2d, SparseConv3d,
SubMConv3d, SubMConv4d)
from spconv.pytorch.identity import Identity
from spconv.pytorch.modules import (SparseModule, SparseSequential,
assign_name_for_sparse_modules)
assign_name_for_sparse_modules, SparseBatchNorm,
SparseReLU, SparseIdentity)
from spconv.pytorch.ops import ConvAlgo
from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d,
SparseMaxPool3d, SparseMaxPool4d,
......
This diff is collapsed.
......@@ -233,6 +233,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
features_th = x_sp.values()
return cls(features_th, indices_th, spatial_shape, batch_size)
def dequantize(self):
    """Return ``replace_feature`` applied to the dequantized features."""
    float_features = self.features.dequantize()
    return self.replace_feature(float_features)
@property
def spatial_size(self):
return np.prod(self.spatial_shape)
......@@ -264,6 +267,19 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
# return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size
def __add__(self, other: "SparseConvTensor"):
    """Add the features of two sparse tensors; the result is built from ``self``."""
    assert isinstance(other, SparseConvTensor)
    summed = self.features + other.features
    return self.replace_feature(summed)
def __iadd__(self, other: "SparseConvTensor"):
    """Add ``other``'s features into this tensor's features in place.

    Args:
        other: sparse tensor whose features are added.

    Returns:
        ``self``, after its feature tensor has been updated.
    """
    assert isinstance(other, SparseConvTensor)
    # Mutate the feature tensor through a local name: `self.features += ...`
    # would also re-assign the attribute after the in-place add.
    # NOTE(review): presumably `features` rejects direct assignment (the rest
    # of the code always goes through replace_feature) - confirm.
    features = self.features
    features += other.features
    return self
def __radd__(self, other: "SparseConvTensor"):
    """Reflected add; the result is built from ``other`` (the left operand)."""
    assert isinstance(other, SparseConvTensor)
    summed = self.features + other.features
    return other.replace_feature(summed)
def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices,
......
......@@ -23,7 +23,7 @@ from spconv import pytorch as spconv
def is_spconv_module(module):
    """Return True when ``module`` is an instance of one of the spconv module types."""
    # SparseBatchNorm and SparseReLU are listed alongside SparseModule; see
    # their class definitions for their base classes.
    spconv_modules = (SparseModule, SparseBatchNorm, SparseReLU)
    return isinstance(module, spconv_modules)
......@@ -148,3 +148,37 @@ def assign_name_for_sparse_modules(module: nn.Module):
for k, n in module.named_modules():
if isinstance(n, SparseModule):
n._sparse_unique_name = k
class SparseBatchNorm(nn.BatchNorm1d):
    """This module exists only for torch.fx transformation for quantization.

    A SparseConvTensor input is normalized on its features and returned via
    ``replace_feature``; any other input goes straight to BatchNorm1d.
    """

    def forward(self, input):
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseSyncBatchNorm(nn.SyncBatchNorm):
    """This module exists only for torch.fx transformation for quantization.

    A SparseConvTensor input is normalized on its features and returned via
    ``replace_feature``; any other input goes straight to SyncBatchNorm.
    """

    def forward(self, input):
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseReLU(nn.ReLU):
    """This module exists only for torch.fx transformation for quantization.

    A SparseConvTensor input has ReLU applied to its features and is
    returned via ``replace_feature``; any other input goes straight to ReLU.
    """

    def forward(self, input):
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        activated = super().forward(input.features)
        return input.replace_feature(activated)
class SparseIdentity(nn.Identity):
    """This module exists only for torch.fx transformation for quantization.

    A SparseConvTensor input has its features passed through Identity and is
    returned via ``replace_feature``; any other input goes straight to
    Identity.
    """

    def forward(self, input):
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        passed_through = super().forward(input.features)
        return input.replace_feature(passed_through)
......@@ -1462,14 +1462,14 @@ def implicit_gemm(features: torch.Tensor,
output_scale: float = 1.0,
scale: Optional[torch.Tensor] = None,
output_add: Optional[torch.Tensor] = None,
output_add_scale: float = 1.0,
output_add_scale: float = 0.0,
output_dtype: Optional[torch.dtype] = None):
stream = get_current_stream()
bias_tv = tv.Tensor()
scale_tv = tv.Tensor()
output_add_tv = tv.Tensor()
if output_add is not None:
assert features.dtype == torch.int8, "fused residual add only support int8"
assert features.dtype == torch.qint8, "fused residual add only support int8"
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)
if scale is not None:
......@@ -1485,7 +1485,7 @@ def implicit_gemm(features: torch.Tensor,
output_dtype = features.dtype
if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
alloc = TorchAllocator(features.device, features.dtype == torch.qint8)
features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_splits_tv = [
......@@ -1963,6 +1963,12 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor,
features = features.contiguous()
out_channel = features.shape[-1]
if features.is_quantized:
out_features = torch._empty_affine_quantized((num_activate_out, out_channel),
scale=features.q_scale(),
dtype=features.dtype,
device=features.device)
else:
out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype,
device=features.device)
......@@ -2016,9 +2022,16 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
features = features.contiguous()
out_channel = features.shape[-1]
if features.is_quantized:
out_features = torch._empty_affine_quantized((num_activate_out, out_channel),
scale=features.q_scale(),
dtype=features.dtype,
device=features.device)
else:
out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype,
device=features.device)
assert features.is_cuda
stream = get_current_stream()
out_features_tv = torch_tensor_to_tv(out_features)
......
......@@ -66,14 +66,14 @@ class SparseMaxPool(SparseModule):
if algo is None:
# keep in mind that this algorithm is set for Inverse Sparse Conv
# maxpool itself don't need mask.
if kv <= 32 and not CPU_ONLY_BUILD:
if kv <= 128 and not CPU_ONLY_BUILD:
if kv < 8:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.MaskImplicitGemm
else:
algo = ConvAlgo.Native
if kv > 32:
if kv > 128:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
......@@ -96,7 +96,10 @@ class SparseMaxPool(SparseModule):
return None
def forward(self, input):
def forward(self, input: spconv.SparseConvTensor):
is_int8 = input.is_quantized
if is_int8:
assert self.algo == ConvAlgo.MaskImplicitGemm, "only ConvAlgo.MaskImplicitGemm support int8."
assert isinstance(input, spconv.SparseConvTensor)
features = input.features
device = features.device
......@@ -296,6 +299,10 @@ class SparseAvgPool(SparseModule):
def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor)
is_int8 = input.is_quantized
if is_int8:
assert self.algo == ConvAlgo.MaskImplicitGemm, "only ConvAlgo.MaskImplicitGemm support int8."
features = input.features
device = features.device
indices = input.indices
......@@ -534,3 +541,8 @@ class SparseAvgPool3d(SparseAvgPool):
algo=algo,
record_voxel_count=record_voxel_count,
name=name)
# Set of the sparse average/max pooling layer classes, base classes included.
ALL_POOL_LAYERS = {
    SparseAvgPool,
    SparseAvgPool1d,
    SparseAvgPool2d,
    SparseAvgPool3d,
    SparseMaxPool,
    SparseMaxPool1d,
    SparseMaxPool2d,
    SparseMaxPool3d,
    SparseMaxPool4d,
}
\ No newline at end of file
......@@ -19,3 +19,4 @@ from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig)
from .qmapping import (get_spconv_fmod_to_qat_mapping,
get_spconv_qat_to_static_mapping)
from .core import quantize_per_tensor
\ No newline at end of file
This diff is collapsed.
from typing import Union, List, Dict
import torch
from spconv.pytorch.core import SparseConvTensor
def _quantize_one(value, scale, zero_point, dtype):
    """Quantize one dense tensor, or the features of one SparseConvTensor."""
    if isinstance(value, SparseConvTensor):
        quantized = torch.quantize_per_tensor(value.features, scale, zero_point, dtype)
        return value.replace_feature(quantized)
    return torch.quantize_per_tensor(value, scale, zero_point, dtype)


def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[Union[SparseConvTensor, torch.Tensor]]], scale, zero_point, dtype):
    """Variant of ``torch.quantize_per_tensor`` that also accepts SparseConvTensor.

    A single input is quantized with ``scale`` / ``zero_point`` as given. For
    a list or tuple input, element ``i`` is quantized with ``scale[i]`` and
    ``zero_point[i]`` and a list is returned.
    """
    if not isinstance(ten, (list, tuple)):
        return _quantize_one(ten, scale, zero_point, dtype)
    return [
        _quantize_one(item, scale[idx], zero_point[idx], dtype)
        for idx, item in enumerate(ten)
    ]
\ No newline at end of file
......@@ -11,7 +11,7 @@ from torch.ao.quantization.observer import (HistogramObserver,
from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER
from typing import Any, Callable, Dict, Tuple, Union, List
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig
from spconv.pytorch.core import SparseConvTensor
__all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"]
......@@ -80,13 +80,14 @@ def get_default_spconv_trt_ptq_qconfig(backend, version):
def get_default_spconv_trt_qat_qconfig(backend, version):
return default_symmetric_spconv_qat_qconfig
def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", version: int = 0) -> QConfigMapping:
def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "fbgemm", version: int = 0) -> QConfigMapping:
"""
From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend.
"""
# get_default_qconfig(backend, version)
if is_qat:
# qconfig = get_default_qat_qconfig(backend, version)
qconfig = get_default_spconv_trt_qat_qconfig(backend, version)
else:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
......@@ -144,3 +145,4 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", versi
.set_object_type(torch.nn.functional.tanh, qconfig)
return qconfig_mapping
from functools import partial
from typing import Union, Callable, Tuple, Dict, Optional, Type, Any
import torch.nn as nn
import spconv.pytorch as spconv
......@@ -5,7 +6,8 @@ from .utils import fuse_spconv_bn_eval
from . import intrinsic as snni
from .intrinsic.qat.modules import SparseConvBn, SparseConvBnReLU, SparseConvBnAddReLU
from spconv.pytorch.conv import DEFAULT_SPARSE_CONV_TYPES
def fuse_conv_bn(is_qat, conv, bn):
def fuse_conv_bn(is_qat, conv, bn, is_add_fuse: bool = False):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
......@@ -20,11 +22,10 @@ def fuse_conv_bn(is_qat, conv, bn):
"""
assert(conv.training == bn.training),\
"Conv and BN both must be in the same mode (train or eval)."
fuse_cls = snni.SpconvAddReLUNd if is_add_fuse else snni.SpconvBnNd
fused_module_class_map = {
k: snni.SpconvBnNd for k in DEFAULT_SPARSE_CONV_TYPES
k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
}
if is_qat:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
......@@ -37,7 +38,7 @@ def fuse_conv_bn(is_qat, conv, bn):
else:
return fuse_spconv_bn_eval(conv, bn)
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
def fuse_conv_bn_relu(is_qat, conv, bn, relu, is_add_fuse: bool = False):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
......@@ -54,8 +55,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[spconv.SparseSequential]] = None
if is_qat:
fuse_cls = snni.SpconvBnAddReLUNd if is_add_fuse else snni.SpconvBnReLUNd
map_to_fused_module_train = {
k: snni.SpconvBnReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
}
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
......@@ -66,8 +68,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else:
fuse_cls = snni.SpconvAddReLUNd if is_add_fuse else snni.SpconvReLUNd
map_to_fused_module_eval = {
k: snni.SpconvReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
}
fused_module = map_to_fused_module_eval.get(type(conv), None)
if fused_module is not None:
......@@ -76,28 +79,21 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
def fuse_conv_bn_add_relu(is_qat, relu, add_pattern):
    r"""Fuser method for a matched relu-over-add pattern containing conv and bn.

    Unpacks the conv and bn modules out of ``add_pattern`` and delegates to
    ``fuse_conv_bn_relu`` with ``is_add_fuse`` set to True.

    Args:
        is_qat: forwarded unchanged to ``fuse_conv_bn_relu``.
        relu: the relu module of the match.
        add_pattern: three-element sequence whose middle element is a
            ``(bn, conv)`` pair; the first and last elements are ignored.
            NOTE(review): presumably ``(add, (bn, conv), extra_input)`` as
            produced by the fx pattern matcher - confirm against the
            backend config.

    Returns:
        Whatever ``fuse_conv_bn_relu`` returns for these modules.
    """
    _, bn_pattern, _ = add_pattern
    bn, conv = bn_pattern
    return fuse_conv_bn_relu(is_qat, conv, bn, relu, True)
# Default map for swapping float module to qat modules
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd
from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd, SpconvAddReLUNd
......@@ -60,3 +60,27 @@ class SpconvBnAddReLUNd(_FusedSparseModule):
isinstance(relu, ReLU), 'Incorrect types for input modules{}{}{}' \
.format(type(conv), type(bn), type(relu))
super().__init__(conv, bn, relu)
def forward(self, input, add_input):
    """Apply conv, then bn on the features, then relu(features + add_input.features)."""
    conv, bn, relu = self[0], self[1], self[2]
    conv_out = conv(input)
    normalized = bn(conv_out.features)
    conv_out = conv_out.replace_feature(normalized)
    activated = relu(conv_out.features + add_input.features)
    return conv_out.replace_feature(activated)
class SpconvAddReLUNd(_FusedSparseModule):
    r"""This is a sequential container holding a sparse conv and a ReLU; its
    forward adds a second input to the conv output before the ReLU.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        types_ok = isinstance(conv, SparseConvolution) and isinstance(relu, ReLU)
        assert types_ok, \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)

    def forward(self, input, add_input):
        conv, relu = self[0], self[1]
        conv_out = conv(input)
        activated = relu(conv_out.features + add_input.features)
        return conv_out.replace_feature(activated)
......@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SparseConvBn, SparseConvBnAddReLU, SparseConvBnReLU, SparseConv, SparseConvReLU
\ No newline at end of file
from .modules import (SparseConv, SparseConvAddReLU, SparseConvBn,
SparseConvBnAddReLU, SparseConvBnReLU, SparseConvReLU)
......@@ -17,7 +17,7 @@ import spconv.pytorch.quantization.intrinsic as snni
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConv(SparseConvolution, nni._FusedModule):
class _SparseConv(SparseConvolution):
_FLOAT_MODULE = MOD
_FLOAT_CONV_MODULE = SparseConvolution
......@@ -67,7 +67,7 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return self._conv_forward(False, input, self.weight_fake_quant(self.weight), self.bias)
return self._conv_forward(self.training, input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
def from_float(cls, mod):
......@@ -77,11 +77,12 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
assert issubclass(type(mod), cls._FLOAT_MODULE), (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
+ f" not {type(mod).__qualname__}"
)
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
......@@ -197,6 +198,33 @@ class SparseConvReLU(SparseConv, nni._FusedModule):
def from_float(cls, mod):
    """Build this module from ``mod`` by delegating to the parent ``from_float``."""
    parent = super(SparseConvReLU, cls)
    return parent.from_float(mod)
class SparseConvAddReLU(SparseConv, nni._FusedModule):
    r"""Fused sparse conv + add + ReLU module, attached with a FakeQuantize
    module for the weight, used for quantization aware training.

    Its float counterpart is ``snni.SpconvAddReLUNd``; no batch norm is
    involved (``_FLOAT_BN_MODULE`` is ``None``).

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = snni.SpconvAddReLUNd
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def forward(self, input, add_input):
        """Convolve with the fake-quantized weight, passing ``add_input``
        through to ``_conv_forward``, then apply ReLU to the features."""
        fq_weight = self.weight_fake_quant(self.weight)
        out = self._conv_forward(self.training, input, fq_weight, self.bias,
                                 add_input=add_input)
        activated = F.relu(out.features)
        return out.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Build this module from ``mod`` by delegating to the parent ``from_float``."""
        parent = super(SparseConvAddReLU, cls)
        return parent.from_float(mod)
class _SparseConvBn(SparseConvolution, nni._FusedModule):
_version = 2
......@@ -323,9 +351,9 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
conv_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv = conv_spt.features
conv_orig = conv / scale_factor.reshape(bias_shape)
conv_orig = conv / scale_factor# .reshape(bias_shape)
if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape(bias_shape)
conv_orig = conv_orig + self.bias# .reshape(bias_shape)
conv = self.bn(conv_orig)
if add_input is not None:
conv = conv + add_input.features
......@@ -377,7 +405,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
conv_out = torch.Tensor()
if self.bn.training:
# needed to compute batch mean/std
conv_spt = self._conv_forward(input, self.weight, zero_bias)
conv_spt = self._conv_forward(self.training, input, self.weight, zero_bias)
conv_out = conv_spt.features
# update bn statistics
with torch.no_grad():
......@@ -393,7 +421,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
self.weight * scale_factor.reshape(weight_shape)
)
# fused conv without bias for inference: (r * W / running_std) * X
conv_bn_spt = self._conv_forward(input, scaled_weight, zero_bias)
conv_bn_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv_bn = conv_bn_spt.features
if self.bn.training:
avg_dims = [0] + list(range(2, len(self.weight.shape)))
......@@ -669,12 +697,12 @@ class SparseConvBnAddReLU(_SparseConvBn):
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE = snni.SpconvBnReLUNd # type: ignore[assignment]
_FLOAT_MODULE = snni.SpconvBnAddReLUNd # type: ignore[assignment]
_FLOAT_CONV_MODULE = SparseConvolution
_FLOAT_BN_MODULE = nn.BatchNorm1d
_FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE = snni.SpconvReLUNd
_FUSED_FLOAT_MODULE = snni.SpconvAddReLUNd
def forward(self, input, add_input):
x = _SparseConvBn._forward(self, input, add_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment