Commit e387ee74 authored by yan.yan's avatar yan.yan
Browse files

sync quantization code

parent b1c57a31
...@@ -13,25 +13,37 @@ ...@@ -13,25 +13,37 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import contextlib
import copy
from typing import Dict, Optional
import torch import torch
import spconv.pytorch as spconv import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torchvision import datasets, transforms from torch.ao.quantization import (DeQuantStub, QuantStub,
get_default_qconfig_mapping)
from torch.ao.quantization.fx._lower_to_native_backend import \
STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
import contextlib from torchvision import datasets, transforms
import torch.cuda.amp
import torch.ao.quantization import spconv.pytorch as spconv
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.ao.quantization.quantize_fx as qfx
from spconv.pytorch.quantization.fake_q import get_default_spconv_qconfig_mapping
import spconv.pytorch.quantization as spconvq import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
from torch.ao.quantization import get_default_qconfig_mapping from spconv.pytorch.quantization.backend_cfg import \
from spconv.pytorch.quantization.backend_cfg import SPCONV_STATIC_LOWER_FUSED_MODULE_MAP SPCONV_STATIC_LOWER_FUSED_MODULE_MAP, SPCONV_STATIC_LOWER_MODULE_MAP
from torch.ao.quantization.fx._lower_to_native_backend import STATIC_LOWER_FUSED_MODULE_MAP from spconv.pytorch.quantization.core import quantize_per_tensor
from spconv.pytorch.quantization.fake_q import \
get_default_spconv_qconfig_mapping
from spconv.pytorch.quantization.intrinsic.modules import SpconvBnAddReLUNd, SpconvAddReLUNd
import spconv.pytorch.quantization.intrinsic.quantized as snniq
@contextlib.contextmanager @contextlib.contextmanager
def identity_ctx(): def identity_ctx():
...@@ -57,6 +69,142 @@ class SparseConvBNReLU(spconv.SparseSequential): ...@@ -57,6 +69,142 @@ class SparseConvBNReLU(spconv.SparseSequential):
nn.ReLU(inplace=False) nn.ReLU(inplace=False)
) )
class SparseBasicBlock(spconv.SparseModule):
    """Residual block supported by spconv quantization.

    Two SubMConv2d stages are wrapped in ``SparseSequential`` containers
    (conv-bn-relu, then conv-bn). The shortcut is added to the features of
    the second stage and the sum goes through a final ReLU.

    Args:
        in_planes: input channels of the first convolution.
        out_planes: output channels of both convolutions.
        stride: stride argument forwarded to both SubMConv2d layers.
        downsample: optional module; when given, ``downsample(x)`` replaces
            the default shortcut.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        # Build order (conv, conv, bn, bn) is kept so parameter creation
        # order matches the original.
        first_conv = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        second_conv = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                        bias=False)
        first_bn = nn.BatchNorm1d(out_planes, momentum=0.1)
        second_bn = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(
            conv=first_conv, bn=first_bn, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=second_conv, bn=second_bn)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # presumably gives the FX pattern matcher an explicit node for the
        # shortcut branch -- confirm against the spconv backend config
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        """Run both conv stages and add the shortcut to the output features."""
        shortcut = self.iden_for_fx_match(x.features)
        out = self.conv2_bn(self.conv1_bn_relu(x))
        if self.downsample is not None:
            # NOTE(review): the result is added to a feature tensor below, so
            # downsample presumably has to return features -- confirm.
            shortcut = self.downsample(x)
        summed = out.features + shortcut
        return out.replace_feature(self.relu(summed))
class SparseBasicBlock1(spconv.SparseModule):
    """Residual block supported by spconv quantization (unfused variant).

    Convolutions, dense ``BatchNorm1d`` layers and ReLUs are separate
    attributes; the norms and activations are applied to ``.features`` and
    written back with ``replace_feature``.

    Args:
        in_planes: input channels of the first convolution.
        out_planes: output channels of both convolutions.
        stride: stride argument forwarded to both SubMConv2d layers.
        downsample: stored on the module, but ``forward`` does not apply it.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.downsample = downsample
        # presumably an explicit FX node for the shortcut branch -- confirm
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        """Apply conv-bn-relu, conv-bn, then add the shortcut and ReLU."""
        shortcut = self.iden_for_fx_match(x.features)
        stage1 = self.conv1(x)
        stage1 = stage1.replace_feature(
            self.relu1(self.norm1(stage1.features)))
        stage2 = self.conv2(stage1)
        stage2 = stage2.replace_feature(self.norm2(stage2.features))
        # The shortcut is always the (identity-wrapped) input features here;
        # self.downsample is intentionally not consulted.
        summed = stage2.features + shortcut
        return stage2.replace_feature(self.relu2(summed))
class SparseBasicBlock2(spconv.SparseModule):
    """Residual block supported by spconv quantization (sparse-module variant).

    Uses ``spconv.SparseBatchNorm``, ``spconv.SparseReLU`` and
    ``spconv.SparseIdentity``, all of which are called directly on the
    ``SparseConvTensor`` rather than on its ``.features``.

    Args:
        in_planes: input channels of the first convolution.
        out_planes: output channels of both convolutions.
        stride: stride argument forwarded to both SubMConv2d layers.
        downsample: optional module; when given, ``downsample(x)`` replaces
            the default shortcut.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                       bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.relu2 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        # presumably an explicit FX node for the shortcut branch -- confirm
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        """Apply conv-bn-relu, conv-bn, then add the shortcut and ReLU."""
        shortcut = self.iden_for_fx_match(x)
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            shortcut = self.downsample(x)
        # The addition is performed on the sparse tensors themselves.
        return self.relu2(out + shortcut)
class SparseBasicBlock3(spconv.SparseModule):
    """Residual block supported by spconv quantization (pre-fused variant).

    The second convolution and the final ReLU are wrapped in a
    ``SpconvAddReLUNd`` module that is called with both the main-branch
    tensor and the shortcut.

    Args:
        in_planes: input channels of the first convolution.
        out_planes: output channels of both convolutions.
        stride: stride argument forwarded to both SubMConv2d layers.
        downsample: optional module; when given, ``downsample(x)`` replaces
            the default shortcut.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1,
                                       bias=False)
        second_conv = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1,
                                        bias=False)
        self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        # NOTE(review): this second batch norm is constructed but never
        # attached to the module or used in forward -- confirm whether the
        # fused module was meant to include it.
        norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
        fused_relu = spconv.SparseReLU(inplace=True)
        self.residual_conv = SpconvAddReLUNd(second_conv, fused_relu)
        self.relu1 = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        # presumably an explicit FX node for the shortcut branch -- confirm
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        """Apply conv-bn-relu, then the fused conv/add/ReLU with the shortcut."""
        shortcut = self.iden_for_fx_match(x)
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)
        if self.downsample is not None:
            shortcut = self.downsample(x)
        return self.residual_conv(out, shortcut)
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
...@@ -126,7 +274,7 @@ class NetV2(nn.Module): ...@@ -126,7 +274,7 @@ class NetV2(nn.Module):
class NetPTQ(nn.Module): class NetPTQ(nn.Module):
"""pytorch currently don't support cuda int8 inference, so """pytorch currently don't support cuda int8 inference, so
we only use sparse ops here. we build a pure sparse network here.
""" """
def __init__(self): def __init__(self):
super(NetPTQ, self).__init__() super(NetPTQ, self).__init__()
...@@ -138,7 +286,6 @@ class NetPTQ(nn.Module): ...@@ -138,7 +286,6 @@ class NetPTQ(nn.Module):
SparseConvBNReLU(64, 64, 3, 2, 1), # 4x4 SparseConvBNReLU(64, 64, 3, 2, 1), # 4x4
spconv.SparseConv2d(64, 10, 4, 4), spconv.SparseConv2d(64, 10, 4, 4),
spconv.ToDense(), spconv.ToDense(),
) )
# self.fc1 = nn.Linear(64 * 1 * 1, 128) # self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10) # self.fc2 = nn.Linear(128, 10)
...@@ -158,22 +305,47 @@ class NetPTQ(nn.Module): ...@@ -158,22 +305,47 @@ class NetPTQ(nn.Module):
# print(x_sp.shape) # print(x_sp.shape)
x = x_sp x = x_sp
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x = self.dequant(x) x = self.dequant(x)
output = F.log_softmax(x, dim=1) output = F.log_softmax(x, dim=1)
return output return output
class ResidualNetPTQ(nn.Module):
    """Pure sparse network containing one residual block.

    PyTorch currently doesn't support CUDA int8 inference, so the whole
    network is built from sparse ops, ending with ``spconv.ToDense``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order in which they are registered.
        layers = [
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock2(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*layers)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Quantize the features, run the sparse net, return log-softmax.

        Args:
            features: feature tensor passed through the quant stub and then
                used as the features of the ``SparseConvTensor``.
            indices: index tensor handed to ``SparseConvTensor`` unchanged.
            batch_size: batch size handed to ``SparseConvTensor``.
        """
        quantized = self.quant(features)
        # The sparse tensor is assembled manually with a fixed 28x28 spatial
        # shape (see SparseConvTensor.from_dense for the dense equivalent).
        sparse_in = spconv.SparseConvTensor(quantized, indices, [28, 28],
                                            batch_size)
        dense_out = self.net(sparse_in)
        flat = torch.flatten(dense_out, 1)
        flat = self.dequant(flat)
        return F.log_softmax(flat, dim=1)
class NetDense(nn.Module): class NetDense(nn.Module):
def __init__(self): def __init__(self):
...@@ -184,6 +356,8 @@ class NetDense(nn.Module): ...@@ -184,6 +356,8 @@ class NetDense(nn.Module):
self.dropout2 = nn.Dropout(0.5) self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128) self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10) self.fc2 = nn.Linear(128, 10)
self.iden = spconv.SparseIdentity()
self.quant = QuantStub() self.quant = QuantStub()
self.dequant = DeQuantStub() self.dequant = DeQuantStub()
...@@ -195,6 +369,7 @@ class NetDense(nn.Module): ...@@ -195,6 +369,7 @@ class NetDense(nn.Module):
x = F.relu(x) x = F.relu(x)
x = self.conv2(x) x = self.conv2(x)
x = F.relu(x) x = F.relu(x)
x = self.iden(x)
x = F.max_pool2d(x, 2) x = F.max_pool2d(x, 2)
x = self.dropout1(x) x = self.dropout1(x)
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
...@@ -299,6 +474,54 @@ def calibrate(args, model: torch.nn.Module, data_loader, device): ...@@ -299,6 +474,54 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else: else:
output = model(image) output = model(image)
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Swap ``torch.quantize_per_tensor`` calls for the spconv version.

    ``torch.quantize_per_tensor`` doesn't support SparseConvTensor, so every
    ``call_function`` node targeting it is retargeted to the custom
    ``quantize_per_tensor`` by an FX transform. The graph module is modified
    in place and also returned.
    """
    for fx_node in m.graph.nodes:
        # Only function calls (e.g. torch.add) are candidates.
        if fx_node.op != 'call_function':
            continue
        # For call_function nodes the target is the callable itself.
        if fx_node.target == torch.quantize_per_tensor:
            fx_node.target = quantize_per_tensor

    # Check that the graph is still well-formed, then regenerate forward().
    m.graph.lint()
    m.recompile()
    return m
def is_dequantize_node(node):
    """Return True when ``node`` is an FX ``call_method`` node for ``dequantize``."""
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "call_method":
        return False
    return node.target == "dequantize"
def _get_module(node: torch.fx.Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if node.op == "call_module" and str(node.target) in modules:
return modules[str(node.target)]
else:
return None
def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
    """Feed quantized tensors straight into ``SparseConvAddReLU`` modules.

    For every ``call_module`` node whose module type is exactly
    ``snniq.SparseConvAddReLU``, the second positional argument is inspected;
    if it is a ``dequantize`` call, the node is rewired to consume that
    call's input instead. Dead code is then eliminated and the graph module
    is recompiled and linted. The model is modified in place and returned.
    """
    name_to_module = dict(model.named_modules(remove_duplicate=False))
    for fx_node in model.graph.nodes:
        if fx_node.op != "call_module":
            continue
        if type(_get_module(fx_node, name_to_module)) != snniq.SparseConvAddReLU:
            continue
        # Only the second input (the residual branch) is examined.
        residual_arg = fx_node.args[1]
        if not is_dequantize_node(residual_arg):
            continue
        assert isinstance(residual_arg, torch.fx.Node)
        quantized_input = residual_arg.args[0]
        fx_node.replace_input_with(residual_arg, quantized_input)

    # Drop dequantize nodes that lost their last user, then rebuild forward()
    # and check that the graph is well-formed.
    model.graph.eliminate_dead_code()
    model.recompile()
    model.graph.lint()
    return model
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
...@@ -361,11 +584,11 @@ def main(): ...@@ -361,11 +584,11 @@ def main():
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu") device = torch.device("cuda" if use_cuda and args.sparse else "cpu")
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu") qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse: if args.sparse:
model = NetPTQ().to(device) model = ResidualNetPTQ().to(device)
else: else:
model = NetDense().to(device) model = NetDense().to(device)
...@@ -401,42 +624,61 @@ def main(): ...@@ -401,42 +624,61 @@ def main():
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader) test(args, model, device, test_loader)
scheduler.step() scheduler.step()
# if args.save_model: if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt") torch.save(model.state_dict(), "mnist_cnn.pt")
model.eval() model.eval()
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
if not args.sparse: if not args.sparse:
model = model.cpu() model = model.cpu()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
model_qat = copy.deepcopy(model)
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
STATIC_LOWER_MODULE_MAP.update(SPCONV_STATIC_LOWER_MODULE_MAP)
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
qconfig_mapping = get_default_spconv_qconfig_mapping(False) qconfig_mapping = get_default_spconv_qconfig_mapping(False)
prepare_cfg = spconvq.get_spconv_prepare_custom_config() prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config() backend_cfg = spconvq.get_spconv_backend_config()
convert_cfg = spconvq.get_spconv_convert_custom_config() # convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic # prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model. # then add observers to fused model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg) prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# prepared_model.print_readable() # print(prepared_model)
print([type(m) for m in prepared_model.modules()]) # breakpoint()
print(prepared_model)
# print(prepared_model)
# calibrate: run model with some inputs # calibrate: run model with some inputs
calibrate(args, prepared_model, test_loader, qdevice) # calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules # convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
converted_model = transform_qdq(converted_model)
# test converted ptq model with int8 kernel
remove_conv_add_dq(converted_model)
converted_model = qfx.convert_to_reference_fx(prepared_model, convert_cfg, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
print([type(m) for m in converted_model.modules()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print(converted_model) print(converted_model)
breakpoint()
test(args, converted_model, qdevice, test_loader) test(args, converted_model, qdevice, test_loader)
# do qat
# qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
# prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# # converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# # breakpoint()
# print(prepared_model_qat)
# train(args, prepared_model_qat, qdevice, train_loader, optimizer, 1)
# converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# converted_model = transform_qdq(converted_model)
# test(args, converted_model, qdevice, test_loader)
# # [type(m) for m in prepared_model_qat.modules()]
# # model.qconfig = get_default_spconv_trt_ptq_qconfig()
# # prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# # convert_custom_config_dict = spconvq.get_convert_custom_config()
# # torch.ao.quantization.prepare(model, inplace=True)
# # print('Post Training Quantization Prepare: Inserting Observers')
# # print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# # test(args, model, device, test_loader)
# print(converted_model)
# you will see some nvrtc compile log here, which means int8 kernel is used.
breakpoint() breakpoint()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -188,10 +188,16 @@ class ConvTunerSimple(ConvTunerSimpleBase): ...@@ -188,10 +188,16 @@ class ConvTunerSimple(ConvTunerSimpleBase):
cudadevrt_p = get_cudadevrt_path() cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt" assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p) cudadevrt = str(cudadevrt_p)
# mod = CummNVRTCModule([kernel],
# cudadevrt_path=cudadevrt,
# verbose=True,
# custom_names=custom_names,
# verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod = CummNVRTCModule([kernel], mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt, cudadevrt_path=cudadevrt,
verbose=False, verbose=False,
custom_names=custom_names) custom_names=custom_names)
mod.load() mod.load()
return mod, kernel return mod, kernel
......
...@@ -18,10 +18,10 @@ from typing import List ...@@ -18,10 +18,10 @@ from typing import List
import pccm import pccm
from pccm.utils import project_is_editable, project_is_installed from pccm.utils import project_is_editable, project_is_installed
from ccimport.compat import InWindows from ccimport.compat import InWindows
from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT, SPCONV_INT8_DEBUG
if project_is_installed(PACKAGE_NAME) and project_is_editable( if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT and False: PACKAGE_NAME) and not DISABLE_JIT and not SPCONV_INT8_DEBUG:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
......
...@@ -116,3 +116,5 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1 ...@@ -116,3 +116,5 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32 = False SPCONV_ALLOW_TF32 = False
SPCONV_INT8_DEBUG = False
\ No newline at end of file
...@@ -19,7 +19,7 @@ from cumm.gemm.algospec.core import TensorOp ...@@ -19,7 +19,7 @@ from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout, from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType) ConvLayoutType, ConvMode, ConvOpType)
from spconv.constants import NDIM_DONT_CARE from spconv.constants import NDIM_DONT_CARE, SPCONV_INT8_DEBUG
class ConvAlgo(Enum): class ConvAlgo(Enum):
...@@ -39,18 +39,18 @@ class AlgoHint(Enum): ...@@ -39,18 +39,18 @@ class AlgoHint(Enum):
# TODO two step build: build gemm kernels first, then bind for every python # TODO two step build: build gemm kernels first, then bind for every python
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "", # *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), # 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "", # *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), # 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], # *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None), # "", 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params( # *gen_shuffle_params(
(128, 128, 32), # (128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32"], "", 2, # (64, 32, 32), ["s8,s8,s8,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None), # kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "", # *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), # 2, kernel.GemmAlgo.SimtDP4A, None),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params( # *gen_shuffle_params(
...@@ -164,39 +164,39 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ ...@@ -164,39 +164,39 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "", # *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s8,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params(
(128, 128, 32),
(32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, # (32, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # TensorOp((8, 8, 16))),
*gen_shuffle_params( # # *gen_shuffle_params(
(128, 256, 32), # # (128, 128, 32),
(64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing, # # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
TensorOp((8, 8, 16))), # # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( # *gen_shuffle_params(
(256, 128, 32), # (128, 256, 32),
(64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing, # (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOp((8, 8, 16))), # TensorOp((8, 8, 16))),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "", # *gen_shuffle_params(
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # (256, 128, 32),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "", # (64, 64, 32), ["s8,s8,s8,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))), # TensorOp((8, 8, 16))),
# *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s8,s32,s32"], "",
# 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
# *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s8,s32,s32"], "",
# 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
] ]
SHUFFLE_AMPERE_PARAMS = [ SHUFFLE_AMPERE_PARAMS = [
*gen_shuffle_params( # *gen_shuffle_params(
(128, 128, 64), # (128, 128, 64),
(64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere, # (64, 64, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))), # TensorOp((8, 8, 16))),
*gen_shuffle_params( # *gen_shuffle_params(
(128, 64, 64), # (128, 64, 64),
(64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere, # (64, 32, 64), ["s8,s8,s8,s32,s32"], "", 3, kernel.GemmAlgo.Ampere,
TensorOp((8, 8, 16))), # TensorOp((8, 8, 16))),
] ]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
...@@ -619,182 +619,170 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -619,182 +619,170 @@ IMPLGEMM_AMPERE_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), ]
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), if not SPCONV_INT8_DEBUG:
NDIM_DONT_CARE, IMPLGEMM_AMPERE_PARAMS.extend([
ConvIterAlgo.Optimized, *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
[2, 3, 4], NDIM_DONT_CARE,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ConvIterAlgo.Optimized,
NHWC, [2, 3, 4],
NHWC, ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
GemmAlgo.Ampere, NHWC,
TensorOp((16, 8, 16)), NHWC,
mask_sparse=True, GemmAlgo.Ampere,
increment_k_first=True, TensorOp((16, 8, 32)),
access_per_vector=1, mask_sparse=True,
is_nvrtc=True, increment_k_first=True,
int8_inference=True), access_per_vector=1,
is_nvrtc=True,
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), int8_inference=True),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 16)), TensorOp((16, 8, 16)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 16)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 16)), TensorOp((16, 8, 16)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3], [2, 3],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 32)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
] ])
IMPLGEMM_TURING_PARAMS = [ IMPLGEMM_TURING_PARAMS = [
...@@ -828,151 +816,6 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -828,151 +816,6 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector=0, access_per_vector=0,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
...@@ -1228,6 +1071,153 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -1228,6 +1071,153 @@ IMPLGEMM_TURING_PARAMS = [
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
] ]
if not SPCONV_INT8_DEBUG:
    # Additional int8-inference kernels for GemmAlgo.Turing, registered only
    # when SPCONV_INT8_DEBUG is falsy. Every entry shares the same arguments
    # except for the three shapes listed per row below.
    # NOTE(review): the first two shapes are presumably the threadblock tile
    # and warp tile, and the literal ``2`` the stage count -- confirm against
    # the signature of gen_conv_params.
    IMPLGEMM_TURING_PARAMS.extend([
        conv_param
        for first_shape, second_shape, tensorop_shape in (
            ((64, 64, 64), (32, 32, 64), (16, 8, 32)),
            ((64, 128, 64), (32, 64, 64), (16, 8, 32)),
            ((64, 128, 32), (32, 64, 32), (16, 8, 16)),
            ((128, 64, 64), (64, 32, 64), (16, 8, 32)),
            ((128, 64, 32), (64, 32, 32), (16, 8, 16)),
            ((128, 256, 64), (64, 128, 64), (16, 8, 32)),
            ((256, 128, 64), (128, 64, 64), (16, 8, 32)),
            ((128, 128, 128), (64, 64, 128), (16, 8, 32)),
            ((128, 128, 64), (64, 64, 64), (16, 8, 32)),
        )
        for conv_param in gen_conv_params(
            ConvFwdAndBwdInput, first_shape, second_shape,
            NDIM_DONT_CARE,
            ConvIterAlgo.Optimized,
            2,
            # A fresh list per call, exactly as in the unrolled form.
            ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
            NHWC,
            NHWC,
            NHWC,
            GemmAlgo.Turing,
            TensorOp(tensorop_shape),
            mask_sparse=True,
            increment_k_first=True,
            access_per_vector=1,
            is_nvrtc=True,
            int8_inference=True)
    ])
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
...@@ -14,7 +14,8 @@ from spconv.pytorch.conv import (SparseConv1d, SparseConv2d, SparseConv3d, ...@@ -14,7 +14,8 @@ from spconv.pytorch.conv import (SparseConv1d, SparseConv2d, SparseConv3d,
SubMConv3d, SubMConv4d) SubMConv3d, SubMConv4d)
from spconv.pytorch.identity import Identity from spconv.pytorch.identity import Identity
from spconv.pytorch.modules import (SparseModule, SparseSequential, from spconv.pytorch.modules import (SparseModule, SparseSequential,
assign_name_for_sparse_modules) assign_name_for_sparse_modules, SparseBatchNorm,
SparseReLU, SparseIdentity)
from spconv.pytorch.ops import ConvAlgo from spconv.pytorch.ops import ConvAlgo
from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d, from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d,
SparseMaxPool3d, SparseMaxPool4d, SparseMaxPool3d, SparseMaxPool4d,
......
...@@ -157,35 +157,10 @@ class SparseConvolutionBase: ...@@ -157,35 +157,10 @@ class SparseConvolutionBase:
batch_size = input.batch_size batch_size = input.batch_size
bias_for_training = bias if training else None bias_for_training = bias if training else None
bias_for_infer = bias if not training else None bias_for_infer = bias if not training else None
output_add_scale = 1.0 output_add_scale = 0.0
if is_int8: if is_int8:
if add_input is not None: if add_input is not None:
output_add_scale = add_input.q_scale() output_add_scale = add_input.q_scale()
# if self.enable_int8_test_mode:
# assert not self.training, "must in eval mode"
# assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# assert bias_for_infer is not None, "conv-bn-relu must be fused"
# assert self._int8_input_scale is not None
# if features.dtype != torch.int8:
# # quantize
# features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# output_scale = self._int8_output_scale
# int8_out_scale = output_scale
# if int8_out_scale is None:
# int8_out_scale = 1
# if add_input is not None:
# assert add_input.int8_scale is not None, "only support int8 add"
# output_add_scale = add_input.int8_scale
# if self._int8_weight.numel() == 0:
# with torch.no_grad():
# assert ALL_WEIGHT_IS_KRSC
# weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# num_1s = [1] * (self.ndim + 1)
# self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# if self._int8_weight_scale.numel() == 0:
# self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# self._int8_bias = bias_for_infer * int8_out_scale
if training: if training:
msg = "act don't support backward, only used in inference" msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg assert self.act_type == tv.gemm.Activation.None_, msg
...@@ -340,9 +315,9 @@ class SparseConvolutionBase: ...@@ -340,9 +315,9 @@ class SparseConvolutionBase:
algo, algo,
input._timer, input._timer,
bias_for_infer, bias_for_infer,
self.act_alpha, act_alpha,
self.act_beta, act_beta,
self.act_type) act_type)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
...@@ -354,9 +329,9 @@ class SparseConvolutionBase: ...@@ -354,9 +329,9 @@ class SparseConvolutionBase:
algo, algo,
input._timer, input._timer,
bias_for_infer, bias_for_infer,
self.act_alpha, act_alpha,
self.act_beta, act_beta,
self.act_type) act_type)
else: else:
out_features = Fsp.indice_conv( out_features = Fsp.indice_conv(
features, features,
...@@ -367,10 +342,9 @@ class SparseConvolutionBase: ...@@ -367,10 +342,9 @@ class SparseConvolutionBase:
algo, algo,
input._timer, input._timer,
bias_for_infer, bias_for_infer,
self.act_alpha, act_type,
self.act_beta, act_beta,
self.act_type) act_type)
else: else:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if datas is not None: if datas is not None:
...@@ -490,9 +464,9 @@ class SparseConvolutionBase: ...@@ -490,9 +464,9 @@ class SparseConvolutionBase:
num_activate_out, masks, training, self.subm, num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum, input._timer, self.fp32_accum,
bias_cur, bias_cur,
self.act_alpha, act_alpha,
self.act_beta, act_beta,
self.act_type) act_type)
else: else:
output_dtype = None output_dtype = None
if output_scale is None: if output_scale is None:
...@@ -503,9 +477,9 @@ class SparseConvolutionBase: ...@@ -503,9 +477,9 @@ class SparseConvolutionBase:
num_activate_out, masks, training, self.subm, num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum, input._timer, self.fp32_accum,
bias_cur, bias_cur,
self.act_alpha, act_alpha,
self.act_beta, act_beta,
self.act_type, act_type,
# TODO do we really need output scale to scale bias in kernel? # TODO do we really need output scale to scale bias in kernel?
1.0 if output_scale is None else output_scale, # output_scale 1.0 if output_scale is None else output_scale, # output_scale
channel_scale, # scale channel_scale, # scale
...@@ -764,446 +738,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule): ...@@ -764,446 +738,6 @@ class SparseConvolution(SparseConvolutionBase, SparseModule):
name=self.name, sparse_unique_name=self._sparse_unique_name, act_type=self.act_type, name=self.name, sparse_unique_name=self._sparse_unique_name, act_type=self.act_type,
act_alpha=self.act_alpha, act_beta=self.act_beta) act_alpha=self.act_alpha, act_beta=self.act_beta)
# def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
# channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None):
# assert isinstance(input, SparseConvTensor)
# is_int8 = input.is_quantized and weight.is_quantized
# if is_int8:
# assert output_scale is not None and channel_scale is not None, "int8 must be called in static quantized module"
# assert bias is not None, "currently you must specify a bias"
# assert input.features.shape[
# 1] == self.in_channels, "channel size mismatch"
# features = input.features
# device = features.device
# indices = input.indices
# spatial_shape = input.spatial_shape
# batch_size = input.batch_size
# bias_for_training = bias if self.training else None
# bias_for_infer = bias if not self.training else None
# output_add_scale = 1.0
# if is_int8:
# if add_input is not None:
# output_add_scale = add_input.q_scale()
# # if self.enable_int8_test_mode:
# # assert not self.training, "must in eval mode"
# # assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# # assert bias_for_infer is not None, "conv-bn-relu must be fused"
# # assert self._int8_input_scale is not None
# # if features.dtype != torch.int8:
# # # quantize
# # features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# # output_scale = self._int8_output_scale
# # int8_out_scale = output_scale
# # if int8_out_scale is None:
# # int8_out_scale = 1
# # if add_input is not None:
# # assert add_input.int8_scale is not None, "only support int8 add"
# # output_add_scale = add_input.int8_scale
# # if self._int8_weight.numel() == 0:
# # with torch.no_grad():
# # assert ALL_WEIGHT_IS_KRSC
# # weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# # num_1s = [1] * (self.ndim + 1)
# # self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# # if self._int8_weight_scale.numel() == 0:
# # self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# # self._int8_bias = bias_for_infer * int8_out_scale
# if self.training:
# msg = "act don't support backward, only used in inference"
# assert self.act_type == tv.gemm.Activation.None_, msg
# if not self.subm:
# if self.transposed:
# out_spatial_shape = ops.get_deconv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding)
# else:
# out_spatial_shape = ops.get_conv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation)
# else:
# out_spatial_shape = spatial_shape
# # print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# # input.update_grid(out_spatial_shape)
# # t = time.time()
# out_tensor = input.shadow_copy()
# if input.benchmark:
# if self.name is None:
# raise ValueError(
# "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
# )
# if self.name not in input.benchmark_record:
# input.benchmark_record[self.name] = {
# "type": "SparseConvolution",
# "indice_gen_time": [],
# "time": [],
# "num_points": [],
# "num_out_points": [],
# "params": {
# "kernel_size": self.kernel_size,
# "stride": self.stride,
# "padding": self.padding,
# "dilation": self.dilation,
# "output_padding": self.output_padding,
# "subm": self.subm,
# "transposed": self.transposed,
# "input_channels": self.in_channels,
# "out_channels": self.out_channels,
# }
# }
# if self.conv1x1 and not is_int8:
# # in int8 test mode, we don't implement conv1x1 via mm.
# if FILTER_HWIO:
# features = torch.mm(
# input.features,
# weight.view(self.out_channels, self.in_channels).T)
# else:
# features = torch.mm(
# input.features,
# weight.view(self.in_channels, self.out_channels))
# if bias is not None:
# features += bias
# out_tensor = out_tensor.replace_feature(features)
# # padding may change spatial shape of conv 1x1.
# out_tensor.spatial_shape = out_spatial_shape
# return out_tensor
# indice_dict = input.indice_dict.copy()
# # only support contiguous tensor for now
# if not features.is_contiguous():
# features = features.contiguous()
# algo = self.algo
# if self.indice_key is not None:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
# assert algo == datas.algo, msg
# # algo = datas.algo
# profile_ctx = nullcontext()
# if input._timer is not None and self._sparse_unique_name:
# profile_ctx = input._timer.namespace(self._sparse_unique_name)
# with profile_ctx:
# if algo == ConvAlgo.Native:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, IndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# out_spatial_shape = datas.spatial_shape
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# try:
# outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
# indices, batch_size, spatial_shape, algo,
# self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, self.subm,
# self.transposed)
# except Exception as e:
# msg = "[Exception|native_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# indice_data = IndiceData(outids,
# indices,
# indice_pairs,
# indice_pair_num,
# spatial_shape,
# out_spatial_shape,
# is_subm=self.subm,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# if self.indice_key is not None:
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# indice_pairs_calc = indice_pairs
# if indice_pairs.device != features.device:
# indice_pairs_calc = indice_pairs.to(features.device)
# if self.subm:
# out_features = Fsp.indice_subm_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# if self.inverse:
# out_features = Fsp.indice_inverse_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# out_features = Fsp.indice_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, ImplicitGemmIndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# pair_fwd = datas.pair_bwd
# pair_bwd = datas.pair_fwd
# pair_mask_fwd_splits = datas.pair_mask_bwd_splits
# pair_mask_bwd_splits = datas.pair_mask_fwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
# masks = datas.masks
# out_spatial_shape = datas.spatial_shape
# # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# pair_fwd = datas.pair_fwd
# pair_bwd = datas.pair_bwd
# pair_mask_fwd_splits = datas.pair_mask_fwd_splits
# pair_mask_bwd_splits = datas.pair_mask_bwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
# masks = datas.masks
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# with input._timer.namespace("gen_pairs"):
# # we need to gen bwd indices for regular conv
# # because it may be inversed.
# try:
# res = ops.get_indice_pairs_implicit_gemm(
# indices,
# batch_size,
# spatial_shape,
# algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation,
# out_padding=self.output_padding,
# subm=self.subm,
# transpose=self.transposed,
# is_train=(not self.subm) or self.training,
# alloc=input.thrust_allocator,
# timer=input._timer)
# except Exception as e:
# msg = "[Exception|implicit_gemm_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# outids = res[0]
# num_inds_per_loc = res[1]
# pair_fwd = res[2]
# pair_bwd = res[3]
# pair_mask_fwd_splits = res[4]
# pair_mask_bwd_splits = res[5]
# mask_argsort_fwd_splits = res[6]
# mask_argsort_bwd_splits = res[7]
# masks = res[8]
# if self.indice_key is not None:
# indice_data = ImplicitGemmIndiceData(
# outids,
# indices,
# pair_fwd,
# pair_bwd,
# pair_mask_fwd_splits=pair_mask_fwd_splits,
# pair_mask_bwd_splits=pair_mask_bwd_splits,
# mask_argsort_fwd_splits=mask_argsort_fwd_splits,
# mask_argsort_bwd_splits=mask_argsort_bwd_splits,
# masks=masks,
# is_subm=self.subm,
# spatial_shape=spatial_shape,
# out_spatial_shape=out_spatial_shape,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# num_activate_out = outids.shape[0]
# weight_cur = weight
# bias_cur = bias_for_infer
# # if self.enable_int8_test_mode:
# # assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
# # weight_cur = self._int8_weight
# # bias_cur = self._int8_bias
# if self.training:
# out_features = Fsp.implicit_gemm(
# features, weight_cur, pair_fwd, pair_bwd,
# pair_mask_fwd_splits, pair_mask_bwd_splits,
# mask_argsort_fwd_splits, mask_argsort_bwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# output_dtype = None
# if output_scale is None:
# output_dtype = weight.dtype
# out_features, _, _ = ops.implicit_gemm(
# features, weight_cur, pair_fwd, pair_mask_fwd_splits,
# mask_argsort_fwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type,
# # TODO do we really need output scale to scale bias in kernel?
# 1.0 if output_scale is None else output_scale, # output_scale
# channel_scale, # scale
# output_add=add_input.features if add_input is not None else None,
# output_add_scale=output_add_scale,
# output_dtype=output_dtype)
# if bias_for_training is not None:
# out_features += bias_for_training
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[self.name]["time"].append(interval)
# out_tensor.benchmark_record[self.name]["num_points"].append(
# features.shape[0])
# out_tensor.benchmark_record[self.name]["num_out_points"].append(
# out_features.shape[0])
# if not self.subm and not self.inverse and self.record_voxel_count:
# if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
# ops.maximum_value_int_(
# getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
# outids.shape[0])
# out_tensor = out_tensor.replace_feature(out_features)
# out_tensor.indices = outids
# out_tensor.indice_dict = indice_dict
# out_tensor.spatial_shape = out_spatial_shape
# if add_input is not None and not is_int8:
# # in int8, we apply add + act in kernel.
# out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
# return out_tensor
# def _check_subm_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# assert datas.is_subm, "only support reuse subm indices"
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"subm with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}")
# if self.dilation != datas.dilation:
# raise ValueError(
# f"subm with same indice_key must have same dilation"
# f", expect {datas.dilation}, this layer {self.dilation}")
# if inp.spatial_shape != datas.spatial_shape:
# raise ValueError(
# f"subm with same indice_key must have same spatial structure"
# f", expect {datas.spatial_shape}, input {spatial_shape}")
# if inp.indices.shape[0] != datas.indices.shape[0]:
# raise ValueError(
# f"subm with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
# )
# def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"Inverse with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.spatial_shape != datas.out_spatial_shape:
# raise ValueError(
# f"Inverse with same indice_key must have same spatial structure (spatial shape)"
# f", expect {datas.spatial_shape}, input {spatial_shape}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.indices.shape[0] != datas.out_indices.shape[0]:
# raise ValueError(
# f"Inverse with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
# "please check Inverse Convolution in ."
# )
class SparseConv1d(SparseConvolution): class SparseConv1d(SparseConvolution):
def __init__(self, def __init__(self,
......
...@@ -233,6 +233,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -233,6 +233,9 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
features_th = x_sp.values() features_th = x_sp.values()
return cls(features_th, indices_th, spatial_shape, batch_size) return cls(features_th, indices_th, spatial_shape, batch_size)
def dequantize(self):
    """Dequantize the features and wrap them via ``replace_feature``.

    Calls ``dequantize()`` on ``self.features`` and returns whatever
    ``self.replace_feature`` produces for the dequantized result.
    """
    dequantized_features = self.features.dequantize()
    return self.replace_feature(dequantized_features)
@property @property
def spatial_size(self): def spatial_size(self):
return np.prod(self.spatial_shape) return np.prod(self.spatial_shape)
...@@ -264,6 +267,19 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -264,6 +267,19 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
# return self.indices.shape[0] / np.prod( # return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size # self.spatial_shape) / self.batch_size
def __add__(self, other: "SparseConvTensor"):
    """Return ``self + other`` as a new tensor built with ``replace_feature``.

    The result is ``self.replace_feature(self.features + other.features)``.
    The two feature tensors are added as-is; this method performs no check
    that both operands describe the same set of indices.

    Returns ``NotImplemented`` when ``other`` is not a ``SparseConvTensor``
    so that Python can try the reflected operation on ``other`` and
    otherwise raise ``TypeError``. An ``assert`` is not used for this
    check because it disappears under ``python -O``.
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    return self.replace_feature(self.features + other.features)
def __iadd__(self, other: "SparseConvTensor"):
    """Add ``other.features`` into ``self.features`` in place; return ``self``.

    Returns ``NotImplemented`` when ``other`` is not a ``SparseConvTensor``
    so that Python falls back to ``__add__`` / the reflected operation
    instead of failing on an ``assert`` (which is stripped under ``-O``).
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    # Apply the in-place add to the feature object itself rather than
    # through ``self.features += ...``: the latter also re-assigns the
    # ``features`` attribute afterwards, which is redundant for an in-place
    # add.
    # NOTE(review): sibling methods only replace features through
    # ``replace_feature``; if ``features`` is a read-only property the
    # attribute re-assignment would raise -- confirm against its setter.
    features = self.features
    features += other.features
    return self
def __radd__(self, other: "SparseConvTensor"):
    """Reflected addition: handle ``other + self``.

    When ``other`` is a ``SparseConvTensor`` the result is
    ``other.replace_feature(self.features + other.features)``, i.e. the
    returned tensor is derived from ``other`` (the left operand).

    Python calls this hook mostly when the left operand does not know how
    to add a ``SparseConvTensor``; in that case ``other`` is usually not a
    ``SparseConvTensor`` either, so ``NotImplemented`` is returned and
    Python raises ``TypeError`` instead of tripping an ``assert``.
    """
    if not isinstance(other, SparseConvTensor):
        return NotImplemented
    return other.replace_feature(self.features + other.features)
def shadow_copy(self) -> "SparseConvTensor": def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged""" """create a new spconv tensor with all member unchanged"""
tensor = SparseConvTensor(self.features, self.indices, tensor = SparseConvTensor(self.features, self.indices,
......
...@@ -23,7 +23,7 @@ from spconv import pytorch as spconv ...@@ -23,7 +23,7 @@ from spconv import pytorch as spconv
def is_spconv_module(module):
    """Return True if ``module`` is an instance of a listed spconv type.

    The accepted types are ``SparseModule``, ``SparseBatchNorm`` and
    ``SparseReLU`` (the tuple introduced by this commit).
    """
    spconv_modules = (SparseModule, SparseBatchNorm, SparseReLU)
    return isinstance(module, spconv_modules)
...@@ -148,3 +148,37 @@ def assign_name_for_sparse_modules(module: nn.Module): ...@@ -148,3 +148,37 @@ def assign_name_for_sparse_modules(module: nn.Module):
for k, n in module.named_modules(): for k, n in module.named_modules():
if isinstance(n, SparseModule): if isinstance(n, SparseModule):
n._sparse_unique_name = k n._sparse_unique_name = k
class SparseBatchNorm(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` that also accepts a ``spconv.SparseConvTensor``.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        """Apply batch norm to ``input`` or, for a sparse tensor, its features.

        A ``spconv.SparseConvTensor`` has its ``features`` normalized and
        the result passed to ``input.replace_feature``; any other input is
        forwarded to ``nn.BatchNorm1d.forward`` unchanged.
        """
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseSyncBatchNorm(nn.SyncBatchNorm):
    """``nn.SyncBatchNorm`` that also accepts a ``spconv.SparseConvTensor``.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        """Apply sync batch norm to ``input`` or to a sparse tensor's features.

        A ``spconv.SparseConvTensor`` has its ``features`` normalized and
        the result passed to ``input.replace_feature``; any other input is
        forwarded to ``nn.SyncBatchNorm.forward`` unchanged.
        """
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        normalized = super().forward(input.features)
        return input.replace_feature(normalized)
class SparseReLU(nn.ReLU):
    """``nn.ReLU`` that also accepts a ``spconv.SparseConvTensor``.

    This module exists only for the torch.fx transformation used for
    quantization.
    """

    def forward(self, input):
        """Apply ReLU to ``input`` or, for a sparse tensor, to its features.

        A ``spconv.SparseConvTensor`` has its ``features`` activated and
        the result passed to ``input.replace_feature``; any other input is
        forwarded to ``nn.ReLU.forward`` unchanged.
        """
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        activated = super().forward(input.features)
        return input.replace_feature(activated)
class SparseIdentity(nn.Identity):
    """``nn.Identity`` that also accepts a ``SparseConvTensor``.

    This module exists only for the torch.fx transformation used by
    quantization.
    """

    def forward(self, input):
        # Non-sparse inputs go straight to the regular Identity forward.
        if not isinstance(input, spconv.SparseConvTensor):
            return super().forward(input)
        # Sparse inputs: pass the features through and swap them back in.
        passthrough = super().forward(input.features)
        return input.replace_feature(passthrough)
...@@ -1462,14 +1462,14 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1462,14 +1462,14 @@ def implicit_gemm(features: torch.Tensor,
output_scale: float = 1.0, output_scale: float = 1.0,
scale: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None,
output_add: Optional[torch.Tensor] = None, output_add: Optional[torch.Tensor] = None,
output_add_scale: float = 1.0, output_add_scale: float = 0.0,
output_dtype: Optional[torch.dtype] = None): output_dtype: Optional[torch.dtype] = None):
stream = get_current_stream() stream = get_current_stream()
bias_tv = tv.Tensor() bias_tv = tv.Tensor()
scale_tv = tv.Tensor() scale_tv = tv.Tensor()
output_add_tv = tv.Tensor() output_add_tv = tv.Tensor()
if output_add is not None: if output_add is not None:
assert features.dtype == torch.int8, "fused residual add only support int8" assert features.dtype == torch.qint8, "fused residual add only support int8"
if bias is not None: if bias is not None:
bias_tv = torch_tensor_to_tv(bias) bias_tv = torch_tensor_to_tv(bias)
if scale is not None: if scale is not None:
...@@ -1485,7 +1485,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1485,7 +1485,7 @@ def implicit_gemm(features: torch.Tensor,
output_dtype = features.dtype output_dtype = features.dtype
if SPCONV_CPP_GEMM and CONV_CPP is not None: if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device) alloc = TorchAllocator(features.device, features.dtype == torch.qint8)
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
pair_fwd_tv = torch_tensor_to_tv(pair_fwd) pair_fwd_tv = torch_tensor_to_tv(pair_fwd)
pair_mask_fwd_splits_tv = [ pair_mask_fwd_splits_tv = [
...@@ -1963,9 +1963,15 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor, ...@@ -1963,9 +1963,15 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor,
features = features.contiguous() features = features.contiguous()
out_channel = features.shape[-1] out_channel = features.shape[-1]
out_features = torch.empty((num_activate_out, out_channel), if features.is_quantized:
dtype=features.dtype, out_features = torch._empty_affine_quantized((num_activate_out, out_channel),
device=features.device) scale=features.q_scale(),
dtype=features.dtype,
device=features.device)
else:
out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype,
device=features.device)
assert features.is_cuda assert features.is_cuda
stream = get_current_stream() stream = get_current_stream()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
...@@ -2016,9 +2022,16 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor, ...@@ -2016,9 +2022,16 @@ def indice_avgpool_implicit_gemm(features: torch.Tensor,
features = features.contiguous() features = features.contiguous()
out_channel = features.shape[-1] out_channel = features.shape[-1]
out_features = torch.empty((num_activate_out, out_channel), if features.is_quantized:
dtype=features.dtype, out_features = torch._empty_affine_quantized((num_activate_out, out_channel),
device=features.device) scale=features.q_scale(),
dtype=features.dtype,
device=features.device)
else:
out_features = torch.empty((num_activate_out, out_channel),
dtype=features.dtype,
device=features.device)
assert features.is_cuda assert features.is_cuda
stream = get_current_stream() stream = get_current_stream()
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
......
...@@ -66,14 +66,14 @@ class SparseMaxPool(SparseModule): ...@@ -66,14 +66,14 @@ class SparseMaxPool(SparseModule):
if algo is None: if algo is None:
# keep in mind that this algorithm is set for Inverse Sparse Conv # keep in mind that this algorithm is set for Inverse Sparse Conv
# maxpool itself don't need mask. # maxpool itself don't need mask.
if kv <= 32 and not CPU_ONLY_BUILD: if kv <= 128 and not CPU_ONLY_BUILD:
if kv < 8: if kv < 8:
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
algo = ConvAlgo.MaskImplicitGemm algo = ConvAlgo.MaskImplicitGemm
else: else:
algo = ConvAlgo.Native algo = ConvAlgo.Native
if kv > 32: if kv > 128:
assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now" assert algo == ConvAlgo.Native, "implicit gemm don't support kv >= 32 for now"
if CPU_ONLY_BUILD: if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm" assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
...@@ -96,7 +96,10 @@ class SparseMaxPool(SparseModule): ...@@ -96,7 +96,10 @@ class SparseMaxPool(SparseModule):
return None return None
def forward(self, input): def forward(self, input: spconv.SparseConvTensor):
is_int8 = input.is_quantized
if is_int8:
assert self.algo == ConvAlgo.MaskImplicitGemm, "only ConvAlgo.MaskImplicitGemm support int8."
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
features = input.features features = input.features
device = features.device device = features.device
...@@ -296,6 +299,10 @@ class SparseAvgPool(SparseModule): ...@@ -296,6 +299,10 @@ class SparseAvgPool(SparseModule):
def forward(self, input): def forward(self, input):
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
is_int8 = input.is_quantized
if is_int8:
assert self.algo == ConvAlgo.MaskImplicitGemm, "only ConvAlgo.MaskImplicitGemm support int8."
features = input.features features = input.features
device = features.device device = features.device
indices = input.indices indices = input.indices
...@@ -534,3 +541,8 @@ class SparseAvgPool3d(SparseAvgPool): ...@@ -534,3 +541,8 @@ class SparseAvgPool3d(SparseAvgPool):
algo=algo, algo=algo,
record_voxel_count=record_voxel_count, record_voxel_count=record_voxel_count,
name=name) name=name)
# Every sparse pooling module class defined in this file (generic and
# dimension-specific variants), collected in one set.
ALL_POOL_LAYERS = {
    SparseAvgPool3d,
    SparseAvgPool2d,
    SparseAvgPool1d,
    SparseMaxPool1d,
    SparseMaxPool2d,
    SparseMaxPool3d,
    SparseMaxPool4d,
    SparseAvgPool,
    SparseMaxPool,
}
\ No newline at end of file
...@@ -19,3 +19,4 @@ from .fake_q import (get_default_spconv_trt_ptq_qconfig, ...@@ -19,3 +19,4 @@ from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig) get_default_spconv_trt_qat_qconfig)
from .qmapping import (get_spconv_fmod_to_qat_mapping, from .qmapping import (get_spconv_fmod_to_qat_mapping,
get_spconv_qat_to_static_mapping) get_spconv_qat_to_static_mapping)
from .core import quantize_per_tensor
\ No newline at end of file
from collections import namedtuple from collections import namedtuple
from typing import List, Dict, Union, Type, Tuple import operator
from typing import Dict, List, Tuple, Type, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from torch.ao.quantization.fx.match_utils import (
MatchAllNode, )
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.backend_config import (BackendConfig, from torch.ao.quantization.backend_config import (BackendConfig,
BackendPatternConfig, BackendPatternConfig,
DTypeConfig, ObservationType, DTypeConfig, ObservationType,
...@@ -11,56 +15,184 @@ from torch.ao.quantization.backend_config import (BackendConfig, ...@@ -11,56 +15,184 @@ from torch.ao.quantization.backend_config import (BackendConfig,
from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig, from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig,
FuseCustomConfig, FuseCustomConfig,
PrepareCustomConfig) PrepareCustomConfig)
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
import spconv.pytorch.conv as sconvmod import spconv.pytorch.conv as sconvmod
from spconv.pytorch.modules import SparseBatchNorm, SparseIdentity, SparseReLU, SparseSyncBatchNorm
import spconv.pytorch.quantization.intrinsic as snni import spconv.pytorch.quantization.intrinsic as snni
import spconv.pytorch.quantization.intrinsic.qat as snniqat import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic.quantized as snniq import spconv.pytorch.quantization.intrinsic.quantized as snniq
import spconv.pytorch.quantization.quantized as snnq import spconv.pytorch.quantization.quantized as snnq
import spconv.pytorch.quantization.quantized.reference as snnqr import spconv.pytorch.quantization.quantized.reference as snnqr
from spconv.pytorch import ToDense
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.pool import ALL_POOL_LAYERS
from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn, from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn,
fuse_conv_bn_relu) fuse_conv_bn_relu,
from spconv.pytorch import ToDense fuse_conv_bn_add_relu)
# One record per sparse conv root class, naming the related classes used by
# the backend config: its batch norm, the reference quantized module, the
# fused float modules and the QAT modules.
_SpConvMetadataDef = namedtuple("_ConvMetadata", [
    "root",
    "bn",
    "reference",
    "fused_conv_relu",
    "fused_conv_bn",
    "fused_conv_bn_relu",
    "fused_conv_add_relu",
    "fused_conv_bn_add_relu",
    "qat",
    "relu_qat",
    "bn_qat",
    "bn_relu_qat",
    "add_relu_qat",
    "bn_add_relu_qat",
])

_SpConvMetadatas: List[_SpConvMetadataDef] = []

# Every default sparse conv type shares the same companion classes.
for t in sconvmod.DEFAULT_SPARSE_CONV_TYPES:
    _SpConvMetadatas.append(
        _SpConvMetadataDef(
            root=t,
            bn=nn.BatchNorm1d,
            reference=snnqr.SpConv,
            fused_conv_relu=snni.SpconvReLUNd,
            fused_conv_bn=snni.SpconvBnNd,
            fused_conv_bn_relu=snni.SpconvBnReLUNd,
            fused_conv_add_relu=snni.SpconvAddReLUNd,
            fused_conv_bn_add_relu=snni.SpconvBnAddReLUNd,
            qat=snniqat.SparseConv,
            relu_qat=snniqat.SparseConvReLU,
            bn_qat=snniqat.SparseConvBn,
            bn_relu_qat=snniqat.SparseConvBnReLU,
            add_relu_qat=snniqat.SparseConvAddReLU,
            bn_add_relu_qat=snniqat.SparseConvBnAddReLU,
        ))

# The generic SparseConvolution class is registered as one more root.
_SpConvMetadatas.append(
    _SpConvMetadataDef(
        root=sconvmod.SparseConvolution,
        bn=nn.BatchNorm1d,
        reference=snnqr.SpConv,
        fused_conv_relu=snni.SpconvReLUNd,
        fused_conv_bn=snni.SpconvBnNd,
        fused_conv_bn_relu=snni.SpconvBnReLUNd,
        fused_conv_add_relu=snni.SpconvAddReLUNd,
        fused_conv_bn_add_relu=snni.SpconvBnAddReLUNd,
        qat=snniqat.SparseConv,
        relu_qat=snniqat.SparseConvReLU,
        bn_qat=snniqat.SparseConvBn,
        bn_relu_qat=snniqat.SparseConvBnReLU,
        add_relu_qat=snniqat.SparseConvAddReLU,
        bn_add_relu_qat=snniqat.SparseConvBnAddReLU,
    ))
def _sequential_wrapper2(sequential): def _sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes """ Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules and always returns the sequential that combines the two input modules
""" """
def fuser_method(is_qat, m1, m2): def fuser_method(is_qat, m1, m2):
return sequential(m1, m2) return sequential(m1, m2)
return fuser_method return fuser_method
# NOTE: the newer backend-config format no longer uses the reversed pattern order.
def _get_spconv_configs(dtype_configs):
def _conv_bn_res_relu_root_node_getter(pattern):
relu, add_pattern = pattern
_, bn_pattern, _ = add_pattern
bn, conv = bn_pattern
return conv
def _conv_bn_res_relu_extra_inputs_getter(pattern):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu, add_pattern = pattern
_, bn_pattern, extra_input = add_pattern
bn, conv = bn_pattern
return [extra_input]
def _conv_res_relu_root_node_getter(pattern):
relu, add_pattern = pattern
_, conv, _ = add_pattern
return conv
def _conv_res_relu_extra_inputs_getter(pattern):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu, add_pattern = pattern
_, conv, extra_input = add_pattern
return [extra_input]
def _get_bn_spconv_configs(bn_cls, dtype_configs):
    """Build the fusion pattern configs for a sparse conv followed by ``bn_cls``.

    For every entry of ``_SpConvMetadatas`` three families of fusion
    patterns are registered, in this order:

    * conv + bn              -> ``convs.fused_conv_bn``
    * conv + bn + relu       -> ``convs.fused_conv_bn_relu``
      (relu being ``torch.nn.ReLU``, ``F.relu`` or ``SparseReLU``)
    * conv + bn + add + relu -> ``convs.fused_conv_bn_add_relu``
      (add being ``torch.add`` or ``operator.add``, relu being ``SparseReLU``)

    Args:
        bn_cls: batch-norm module class matched after the conv.
        dtype_configs: dtype configs attached to every pattern config.

    Returns:
        List of ``BackendPatternConfig`` objects in registration order.
    """
    relu_types = (torch.nn.ReLU, F.relu, SparseReLU)
    add_ops = (torch.add, operator.add)
    # Only the sparse-aware relu is matched for the residual pattern.
    residual_relu_types = (SparseReLU, )
    conv_configs = []

    if PYTORCH_VERSION[:2] <= [1, 13]:
        # This branch writes patterns in reversed (nested) order, so the
        # forward-order fuser methods are wrapped with reverse2/reverse3.
        from torch.ao.quantization.fuser_method_mappings import (reverse2,
                                                                 reverse3)
        for convs in _SpConvMetadatas:
            # conv + bn fusion
            conv_configs.append(
                BackendPatternConfig((bn_cls, convs.root))
                .set_dtype_configs(dtype_configs)
                .set_fuser_method(reverse2(fuse_conv_bn))
                .set_fused_module(convs.fused_conv_bn))
            # conv + bn + relu fusion
            for relu_type in relu_types:
                conv_configs.append(
                    BackendPatternConfig((relu_type, (bn_cls, convs.root)))
                    .set_dtype_configs(dtype_configs)
                    .set_fuser_method(reverse3(fuse_conv_bn_relu))
                    .set_fused_module(convs.fused_conv_bn_relu))
            # conv + bn + add + relu fused into one op; the getters pick the
            # conv as root node and the other add operand as extra input.
            for add_op in add_ops:
                for relu_op in residual_relu_types:
                    conv_configs.append(
                        BackendPatternConfig(
                            (relu_op, (add_op, (bn_cls, convs.root),
                                       MatchAllNode)))
                        .set_dtype_configs(dtype_configs)
                        .set_fuser_method(fuse_conv_bn_add_relu)
                        ._set_root_node_getter(
                            _conv_bn_res_relu_root_node_getter)
                        ._set_extra_inputs_getter(
                            _conv_bn_res_relu_extra_inputs_getter)
                        .set_fused_module(convs.fused_conv_bn_add_relu))
    else:
        for convs in _SpConvMetadatas:
            # conv + bn fusion (forward-order pattern)
            conv_configs.append(
                BackendPatternConfig((convs.root, bn_cls))
                .set_dtype_configs(dtype_configs)
                .set_fuser_method(fuse_conv_bn)
                .set_fused_module(convs.fused_conv_bn))
            # conv + bn + relu fusion
            for relu_type in relu_types:
                conv_configs.append(
                    BackendPatternConfig((convs.root, bn_cls, relu_type))
                    .set_dtype_configs(dtype_configs)
                    .set_fuser_method(fuse_conv_bn_relu)
                    .set_fused_module(convs.fused_conv_bn_relu))
            # conv + bn + add + relu: the pattern is still given in the
            # reversed nested form, via _set_pattern_complex_format.
            for add_op in add_ops:
                for relu_op in residual_relu_types:
                    conv_configs.append(
                        BackendPatternConfig()
                        ._set_pattern_complex_format(
                            (relu_op, (add_op, (bn_cls, convs.root),
                                       MatchAllNode)))
                        .set_dtype_configs(dtype_configs)
                        .set_fuser_method(fuse_conv_bn_add_relu)
                        ._set_root_node_getter(
                            _conv_bn_res_relu_root_node_getter)
                        ._set_extra_inputs_getter(
                            _conv_bn_res_relu_extra_inputs_getter)
                        .set_fused_module(convs.fused_conv_bn_add_relu))
    return conv_configs
def _get_spconv_configs(dtype_configs):
"""
Return all configs related to conv modules and ops.
"""
conv_configs = (_get_bn_spconv_configs(SparseBatchNorm, dtype_configs) +
_get_bn_spconv_configs(nn.BatchNorm1d, dtype_configs) +
_get_bn_spconv_configs(SparseSyncBatchNorm, dtype_configs))
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
if PYTORCH_VERSION[:2] <= [1, 13]:
from torch.ao.quantization.fuser_method_mappings import ( from torch.ao.quantization.fuser_method_mappings import (
reverse2, reverse3, reverse_sequential_wrapper2) reverse2, reverse3, reverse_sequential_wrapper2)
for convs in _SpConvMetadatas: for convs in _SpConvMetadatas:
...@@ -68,114 +200,90 @@ def _get_spconv_configs(dtype_configs): ...@@ -68,114 +200,90 @@ def _get_spconv_configs(dtype_configs):
# ----------------------------------- # -----------------------------------
# conv module # conv module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.root) BackendPatternConfig(convs.root).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference) convs.reference).set_qat_module(convs.qat))
.set_qat_module(convs.qat))
# conv qat module # conv qat module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.qat) BackendPatternConfig(convs.qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# (2) Conv + relu # (2) Conv + relu
# ----------------- # -----------------
# 2.1 conv module + relu fusion configs # 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module # conv relu fusion, conv module + relu module
conv_configs.append( for relu_type in [torch.nn.ReLU, F.relu, SparseReLU]:
BackendPatternConfig((torch.nn.ReLU, convs.root)) conv_configs.append(
.set_dtype_configs(dtype_configs) # noqa: E131 BackendPatternConfig(
.set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) (relu_type, convs.root)).set_dtype_configs(
.set_fused_module(convs.fused_conv_relu)) dtype_configs) # noqa: E131
# conv relu fusion, conv module + functional relu .set_fuser_method(
conv_configs.append( reverse_sequential_wrapper2(
BackendPatternConfig((F.relu, convs.root)) convs.fused_conv_relu)).set_fused_module(
.set_dtype_configs(dtype_configs) # noqa: E131 convs.fused_conv_relu))
.set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs # 2.2 conv module + relu fused module configs
# conv relu, fused module # conv relu, fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu) BackendPatternConfig(
.set_observation_type(observation_type) # noqa: E131 convs.fused_conv_relu).set_observation_type(
.set_dtype_configs(dtype_configs) observation_type) # noqa: E131
.set_root_module(convs.root) .set_dtype_configs(dtype_configs).set_root_module(
.set_reference_quantized_module(convs.reference) convs.root).set_reference_quantized_module(
.set_qat_module(convs.relu_qat)) convs.reference).set_qat_module(convs.relu_qat))
# conv relu, qat fused module # conv relu, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.relu_qat) BackendPatternConfig(convs.relu_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# 2.3 functional conv + relu configs # 2.3 functional conv + relu configs
# fused conv relu # fused conv relu
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu) BackendPatternConfig(convs.fused_conv_relu).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_qat_module(convs.relu_qat)) .set_qat_module(convs.relu_qat))
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.relu_qat) BackendPatternConfig(convs.relu_qat).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_root_module(convs.root) .set_root_module(convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs.append(
BackendPatternConfig((convs.bn, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse2(fuse_conv_bn))
.set_fused_module(convs.fused_conv_bn))
# conv + bn + relu module fusion
conv_configs.append(
BackendPatternConfig((nn.ReLU, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
BackendPatternConfig((F.relu, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_fuser_method(reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well # TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs # 3.2 conv + bn (+ relu) fused module configs
# fused conv bn # fused conv bn
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn) BackendPatternConfig(convs.fused_conv_bn).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_qat_module(convs.bn_qat)) .set_qat_module(convs.bn_qat))
# fused conv bn relu # fused conv bn relu
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn_relu) BackendPatternConfig(
.set_dtype_configs(dtype_configs) # noqa: E131 convs.fused_conv_bn_relu).set_dtype_configs(
.set_qat_module(convs.bn_relu_qat)) dtype_configs) # noqa: E131
.set_qat_module(convs.bn_relu_qat))
# conv bn, qat fused module # conv bn, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.bn_qat) BackendPatternConfig(convs.bn_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# conv bn relu, qat fused module # conv bn relu, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.bn_relu_qat) BackendPatternConfig(convs.bn_relu_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# (4) conv transpose and its fusion # (4) conv transpose and its fusion
# 4.1 conv transpose config # 4.1 conv transpose config
...@@ -192,6 +300,56 @@ def _get_spconv_configs(dtype_configs): ...@@ -192,6 +300,56 @@ def _get_spconv_configs(dtype_configs):
# .set_fuser_method(reverse2(fuse_conv_bn)) # .set_fuser_method(reverse2(fuse_conv_bn))
# .set_root_module(convs.transpose) # .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference)) # .set_reference_quantized_module(convs.transpose_reference))
# (5) conv add and its fusion
# 5.1 fuse conv + bn + add + relu to one op
for add_op in [torch.add, operator.add]:
for relu_op in [SparseReLU]:
conv_configs.append(
BackendPatternConfig((relu_op, (add_op, convs.root, MatchAllNode)))
.set_dtype_configs(dtype_configs)
# .set_root_module(convs.root)
.set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_add_relu)) \
._set_root_node_getter(_conv_res_relu_root_node_getter) \
._set_extra_inputs_getter(_conv_res_relu_extra_inputs_getter)
.set_fused_module(convs.fused_conv_add_relu))
# 5.2 fused add
# fused conv bn relu
conv_configs.append(
BackendPatternConfig(
convs.fused_conv_add_relu).set_dtype_configs(
dtype_configs) # noqa: E131
.set_qat_module(convs.add_relu_qat))
conv_configs.append(
BackendPatternConfig(
convs.fused_conv_bn_add_relu).set_dtype_configs(
dtype_configs) # noqa: E131
.set_qat_module(convs.bn_add_relu_qat))
conv_configs.append(
BackendPatternConfig(
convs.fused_conv_add_relu).set_observation_type(
observation_type) # noqa: E131
.set_dtype_configs(dtype_configs).set_root_module(
convs.root).set_reference_quantized_module(
convs.reference).set_qat_module(convs.add_relu_qat))
# conv bn, qat fused module
conv_configs.append(
BackendPatternConfig(convs.add_relu_qat).set_observation_type(
observation_type) # noqa: E131
.set_dtype_configs(dtype_configs).set_root_module(
convs.root).set_reference_quantized_module(
convs.reference))
conv_configs.append(
BackendPatternConfig(
convs.bn_add_relu_qat).set_observation_type(
observation_type) # noqa: E131
.set_dtype_configs(dtype_configs).set_root_module(
convs.root).set_reference_quantized_module(
convs.reference))
return conv_configs return conv_configs
else: else:
for convs in _SpConvMetadatas: for convs in _SpConvMetadatas:
...@@ -199,114 +357,102 @@ def _get_spconv_configs(dtype_configs): ...@@ -199,114 +357,102 @@ def _get_spconv_configs(dtype_configs):
# ----------------------------------- # -----------------------------------
# conv module # conv module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.root) BackendPatternConfig(convs.root).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference) convs.reference).set_qat_module(convs.qat))
.set_qat_module(convs.qat))
# conv qat module # conv qat module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.qat) BackendPatternConfig(convs.qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# (2) Conv + relu # (2) Conv + relu
# ----------------- # -----------------
# 2.1 conv module + relu fusion configs # 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module # conv relu fusion, conv module + relu module
conv_configs.append( for relu_type in [torch.nn.ReLU, F.relu, SparseReLU]:
BackendPatternConfig((convs.root, torch.nn.ReLU)) conv_configs.append(
.set_dtype_configs(dtype_configs) # noqa: E131 BackendPatternConfig(
.set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) (convs.root, relu_type)).set_dtype_configs(
.set_fused_module(convs.fused_conv_relu)) dtype_configs) # noqa: E131
# conv relu fusion, conv module + functional relu .set_fuser_method(
conv_configs.append( _sequential_wrapper2(
BackendPatternConfig((convs.root, F.relu)) convs.fused_conv_relu)).set_fused_module(
.set_dtype_configs(dtype_configs) # noqa: E131 convs.fused_conv_relu))
.set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs # 2.2 conv module + relu fused module configs
# conv relu, fused module # conv relu, fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu) BackendPatternConfig(
.set_observation_type(observation_type) # noqa: E131 convs.fused_conv_relu).set_observation_type(
.set_dtype_configs(dtype_configs) observation_type) # noqa: E131
.set_root_module(convs.root) .set_dtype_configs(dtype_configs).set_root_module(
.set_reference_quantized_module(convs.reference) convs.root).set_reference_quantized_module(
.set_qat_module(convs.relu_qat)) convs.reference).set_qat_module(convs.relu_qat))
# conv relu, qat fused module # conv relu, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.relu_qat) BackendPatternConfig(convs.relu_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# fused conv relu # fused conv relu
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu) BackendPatternConfig(convs.fused_conv_relu).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_qat_module(convs.relu_qat)) .set_qat_module(convs.relu_qat))
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.relu_qat) BackendPatternConfig(convs.relu_qat).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_root_module(convs.root) .set_root_module(convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# (3) Conv + batchnorm (+ relu) # (3) Conv + batchnorm (+ relu)
# ------------------------------- # -------------------------------
# 3.1 conv bn fusion configs # 3.1 conv bn fusion configs
# conv + bn fusion # conv + bn fusion
conv_configs.append(
BackendPatternConfig((convs.root, convs.bn)) # # conv + bn + relu functional fusion
.set_dtype_configs(dtype_configs) # noqa: E131 # conv_configs.append(
.set_fuser_method(fuse_conv_bn) # BackendPatternConfig((convs.root, convs.bn, F.relu))
.set_fused_module(convs.fused_conv_bn)) # .set_dtype_configs(dtype_configs) # noqa: E131
# conv + bn + relu module fusion # .set_root_module(convs.root)
conv_configs.append( # .set_fuser_method(fuse_conv_bn_relu)
BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) # .set_fused_module(convs.fused_conv_bn_relu))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
BackendPatternConfig((convs.root, convs.bn, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well # TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs # 3.2 conv + bn (+ relu) fused module configs
# fused conv bn # fused conv bn
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn) BackendPatternConfig(convs.fused_conv_bn).set_dtype_configs(
.set_dtype_configs(dtype_configs) # noqa: E131 dtype_configs) # noqa: E131
.set_qat_module(convs.bn_qat)) .set_qat_module(convs.bn_qat))
# fused conv bn relu # fused conv bn relu
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn_relu) BackendPatternConfig(
.set_dtype_configs(dtype_configs) # noqa: E131 convs.fused_conv_bn_relu).set_dtype_configs(
.set_qat_module(convs.bn_relu_qat)) dtype_configs) # noqa: E131
.set_qat_module(convs.bn_relu_qat))
# conv bn, qat fused module # conv bn, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.bn_qat) BackendPatternConfig(convs.bn_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# conv bn relu, qat fused module # conv bn relu, qat fused module
conv_configs.append( conv_configs.append(
BackendPatternConfig(convs.bn_relu_qat) BackendPatternConfig(convs.bn_relu_qat).set_observation_type(
.set_observation_type(observation_type) # noqa: E131 observation_type) # noqa: E131
.set_dtype_configs(dtype_configs) .set_dtype_configs(dtype_configs).set_root_module(
.set_root_module(convs.root) convs.root).set_reference_quantized_module(
.set_reference_quantized_module(convs.reference)) convs.reference))
# # (4) conv transpose and its fusion # # (4) conv transpose and its fusion
# # 4.1 conv transpose config # # 4.1 conv transpose config
...@@ -323,38 +469,123 @@ def _get_spconv_configs(dtype_configs): ...@@ -323,38 +469,123 @@ def _get_spconv_configs(dtype_configs):
# .set_fuser_method(fuse_conv_bn) # .set_fuser_method(fuse_conv_bn)
# .set_root_module(convs.transpose) # .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference)) # .set_reference_quantized_module(convs.transpose_reference))
# (5) conv add and its fusion
# 5.1 fuse conv + bn + add + relu to one op
for add_op in [torch.add, operator.add]:
for relu_op in [SparseReLU]:
conv_configs.append(
BackendPatternConfig() \
._set_pattern_complex_format((relu_op, (add_op, convs.root, MatchAllNode)))
.set_dtype_configs(dtype_configs)
# .set_root_module(convs.root)
.set_fuser_method(_sequential_wrapper2(convs.fused_conv_add_relu)) \
._set_root_node_getter(_conv_res_relu_root_node_getter) \
._set_extra_inputs_getter(_conv_res_relu_extra_inputs_getter)
.set_fused_module(convs.fused_conv_add_relu))
# 5.2 fused add
# fused conv bn relu
conv_configs.append(
BackendPatternConfig(
convs.fused_conv_add_relu).set_dtype_configs(
dtype_configs) # noqa: E131
.set_qat_module(convs.add_relu_qat))
conv_configs.append(
BackendPatternConfig(
convs.fused_conv_bn_add_relu).set_dtype_configs(
dtype_configs) # noqa: E131
.set_qat_module(convs.bn_add_relu_qat))
# conv bn, qat fused module
conv_configs.append(
BackendPatternConfig(convs.add_relu_qat).set_observation_type(
observation_type) # noqa: E131
.set_dtype_configs(dtype_configs).set_root_module(
convs.root).set_reference_quantized_module(
convs.reference))
conv_configs.append(
BackendPatternConfig(
convs.bn_add_relu_qat).set_observation_type(
observation_type) # noqa: E131
.set_dtype_configs(dtype_configs).set_root_module(
convs.root).set_reference_quantized_module(
convs.reference))
return conv_configs return conv_configs
def _get_share_observer_ops(dtype_configs):
    """Build pattern configs for modules whose output shares the input observer.

    One ``BackendPatternConfig`` is created for ``ToDense``, then for
    ``SparseIdentity``, then for every entry of ``ALL_POOL_LAYERS`` (in that
    order). Each config is given the observation type
    ``ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT`` and the supplied
    ``dtype_configs``.

    Args:
        dtype_configs: passed unchanged to ``set_dtype_configs`` of every
            created pattern config.

    Returns:
        List of the created ``BackendPatternConfig`` objects.
    """
    shared_observer_types = [ToDense, SparseIdentity, *ALL_POOL_LAYERS]
    configs: List[BackendPatternConfig] = [
        BackendPatternConfig(module_type)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .set_dtype_configs(dtype_configs)
        for module_type in shared_observer_types
    ]
    return configs
# dtype signature for ops that carry weights: qint8 input, output and weight,
# with the bias kept in float.
weighted_op_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.qint8,
    output_dtype=torch.qint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# dtype signature for ops without weights: only qint8 input and output.
non_weighted_op_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.qint8,
    output_dtype=torch.qint8,
)

conv_dtype_configs = [
    weighted_op_qint8_dtype_config,
]

# Start from the TensorRT backend config and install the spconv patterns: the
# conv patterns use the weighted dtype config, the observer-sharing ops
# (built by _get_share_observer_ops) use the non-weighted one.
# The former standalone `_to_dense_cfg` entry is gone: ToDense is now produced
# by _get_share_observer_ops, which also attaches dtype configs to it.
backend_config = get_tensorrt_backend_config() \
    .set_backend_pattern_configs(
        _get_spconv_configs(conv_dtype_configs) +
        _get_share_observer_ops([non_weighted_op_qint8_dtype_config]))
# Fused float module type -> (reference module type, quantized fused module type).
# NOTE(review): the name mirrors torch's STATIC_LOWER_FUSED_MODULE_MAP;
# presumably callers merge this into it before lowering - confirm at call sites.
# This is the only assignment of the name: a second, one-entry assignment that
# followed used to rebind it and drop the SpconvAddReLUNd entry.
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
    Type[nn.Module], Type[WeightedQuantizedModule]]] = {
        snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
        snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU),
    }

# Reference module type -> quantized module type.
SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module],
                                     Type[WeightedQuantizedModule]] = {
    snnqr.SpConv: snnq.SparseConv,
}
def get_spconv_backend_config():
    """Return the module-level ``backend_config`` object.

    The same object is returned on every call; no copy is made.
    """
    return backend_config
def get_spconv_prepare_custom_config():
    """Create a ``PrepareCustomConfig`` listing spconv module classes as non-traceable.

    The list holds every type in ``sconvmod.DEFAULT_SPARSE_CONV_TYPES``
    followed by ``SparseReLU``, ``SparseBatchNorm`` and
    ``SparseSyncBatchNorm``.

    Returns:
        A new ``PrepareCustomConfig`` instance.
    """
    prepare_cfg = PrepareCustomConfig()
    prepare_cfg.non_traceable_module_classes = [
        *sconvmod.DEFAULT_SPARSE_CONV_TYPES,
        SparseReLU,
        SparseBatchNorm,
        SparseSyncBatchNorm,
    ]
    return prepare_cfg
def get_spconv_convert_custom_config():
    """Create a ``ConvertCustomConfig`` with the spconv observed-to-quantized mappings.

    Registers one mapping per fused spconv module type:

    * ``snni.SpconvReLUNd``    -> ``snniq.SparseConvReLU``
    * ``snni.SpconvAddReLUNd`` -> ``snniq.SparseConvAddReLU``

    Returns:
        A new ``ConvertCustomConfig`` instance.
    """
    cfg = ConvertCustomConfig()
    cfg.set_observed_to_quantized_mapping(snni.SpconvReLUNd,
                                          snniq.SparseConvReLU)
    # SpconvAddReLUNd.forward takes an extra ``add_input``; map it to the
    # add-aware quantized module, the same target that
    # SPCONV_STATIC_LOWER_FUSED_MODULE_MAP uses (previously SparseConvReLU).
    cfg.set_observed_to_quantized_mapping(snni.SpconvAddReLUNd,
                                          snniq.SparseConvAddReLU)
    return cfg
\ No newline at end of file
from typing import Union, List, Dict
import torch
from spconv.pytorch.core import SparseConvTensor
def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[Union[SparseConvTensor, torch.Tensor]]], scale, zero_point, dtype):
    """Apply ``torch.quantize_per_tensor`` to dense or sparse tensors.

    Args:
        ten: a ``torch.Tensor``, a ``SparseConvTensor``, or a list/tuple of
            either.
        scale: quantization scale. When ``ten`` is a list/tuple it is indexed
            by position (``scale[i]`` goes with ``ten[i]``).
        zero_point: quantization zero point, indexed by position in the same
            way as ``scale`` for list/tuple input.
        dtype: target quantized dtype, passed through to
            ``torch.quantize_per_tensor``.

    Returns:
        For a single input, the quantized tensor; a ``SparseConvTensor`` input
        yields the result of ``replace_feature`` on it with the quantized
        features. For list/tuple input, a ``list`` of such results in input
        order (a tuple input also gives a list).
    """

    def _quantize_one(value, value_scale, value_zero_point):
        # Dense tensors are quantized directly; sparse tensors only have
        # their ``features`` quantized and swapped in via replace_feature.
        if not isinstance(value, SparseConvTensor):
            return torch.quantize_per_tensor(value, value_scale,
                                             value_zero_point, dtype)
        quantized_features = torch.quantize_per_tensor(
            value.features, value_scale, value_zero_point, dtype)
        return value.replace_feature(quantized_features)

    if not isinstance(ten, (list, tuple)):
        return _quantize_one(ten, scale, zero_point)
    return [
        _quantize_one(item, scale[idx], zero_point[idx])
        for idx, item in enumerate(ten)
    ]
\ No newline at end of file
...@@ -11,7 +11,7 @@ from torch.ao.quantization.observer import (HistogramObserver, ...@@ -11,7 +11,7 @@ from torch.ao.quantization.observer import (HistogramObserver,
from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER
from typing import Any, Callable, Dict, Tuple, Union, List from typing import Any, Callable, Dict, Tuple, Union, List
from torch.ao.quantization import get_default_qconfig from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
__all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"] __all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"]
...@@ -80,13 +80,14 @@ def get_default_spconv_trt_ptq_qconfig(backend, version): ...@@ -80,13 +80,14 @@ def get_default_spconv_trt_ptq_qconfig(backend, version):
def get_default_spconv_trt_qat_qconfig(backend, version): def get_default_spconv_trt_qat_qconfig(backend, version):
return default_symmetric_spconv_qat_qconfig return default_symmetric_spconv_qat_qconfig
def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", version: int = 0) -> QConfigMapping: def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "fbgemm", version: int = 0) -> QConfigMapping:
""" """
From torch.ao.quantization.qconfig_mapping From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend. Return the default QConfigMapping for the given quantization type and backend.
""" """
# get_default_qconfig(backend, version) # get_default_qconfig(backend, version)
if is_qat: if is_qat:
# qconfig = get_default_qat_qconfig(backend, version)
qconfig = get_default_spconv_trt_qat_qconfig(backend, version) qconfig = get_default_spconv_trt_qat_qconfig(backend, version)
else: else:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8), # qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
...@@ -144,3 +145,4 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", versi ...@@ -144,3 +145,4 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", versi
.set_object_type(torch.nn.functional.tanh, qconfig) .set_object_type(torch.nn.functional.tanh, qconfig)
return qconfig_mapping return qconfig_mapping
from functools import partial
from typing import Union, Callable, Tuple, Dict, Optional, Type, Any from typing import Union, Callable, Tuple, Dict, Optional, Type, Any
import torch.nn as nn import torch.nn as nn
import spconv.pytorch as spconv import spconv.pytorch as spconv
...@@ -5,7 +6,8 @@ from .utils import fuse_spconv_bn_eval ...@@ -5,7 +6,8 @@ from .utils import fuse_spconv_bn_eval
from . import intrinsic as snni from . import intrinsic as snni
from .intrinsic.qat.modules import SparseConvBn, SparseConvBnReLU, SparseConvBnAddReLU from .intrinsic.qat.modules import SparseConvBn, SparseConvBnReLU, SparseConvBnAddReLU
from spconv.pytorch.conv import DEFAULT_SPARSE_CONV_TYPES from spconv.pytorch.conv import DEFAULT_SPARSE_CONV_TYPES
def fuse_conv_bn(is_qat, conv, bn):
def fuse_conv_bn(is_qat, conv, bn, is_add_fuse: bool = False):
r"""Given the conv and bn modules, fuses them and returns the fused module r"""Given the conv and bn modules, fuses them and returns the fused module
Args: Args:
...@@ -20,11 +22,10 @@ def fuse_conv_bn(is_qat, conv, bn): ...@@ -20,11 +22,10 @@ def fuse_conv_bn(is_qat, conv, bn):
""" """
assert(conv.training == bn.training),\ assert(conv.training == bn.training),\
"Conv and BN both must be in the same mode (train or eval)." "Conv and BN both must be in the same mode (train or eval)."
fuse_cls = snni.SpconvAddReLUNd if is_add_fuse else snni.SpconvBnNd
fused_module_class_map = { fused_module_class_map = {
k: snni.SpconvBnNd for k in DEFAULT_SPARSE_CONV_TYPES k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
} }
if is_qat: if is_qat:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
...@@ -37,7 +38,7 @@ def fuse_conv_bn(is_qat, conv, bn): ...@@ -37,7 +38,7 @@ def fuse_conv_bn(is_qat, conv, bn):
else: else:
return fuse_spconv_bn_eval(conv, bn) return fuse_spconv_bn_eval(conv, bn)
def fuse_conv_bn_relu(is_qat, conv, bn, relu): def fuse_conv_bn_relu(is_qat, conv, bn, relu, is_add_fuse: bool = False):
r"""Given the conv and bn modules, fuses them and returns the fused module r"""Given the conv and bn modules, fuses them and returns the fused module
Args: Args:
...@@ -54,8 +55,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): ...@@ -54,8 +55,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)." "Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[spconv.SparseSequential]] = None fused_module : Optional[Type[spconv.SparseSequential]] = None
if is_qat: if is_qat:
fuse_cls = snni.SpconvBnAddReLUNd if is_add_fuse else snni.SpconvBnReLUNd
map_to_fused_module_train = { map_to_fused_module_train = {
k: snni.SpconvBnReLUNd for k in DEFAULT_SPARSE_CONV_TYPES k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
} }
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True' assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
...@@ -66,8 +68,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): ...@@ -66,8 +68,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else: else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu))) raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else: else:
fuse_cls = snni.SpconvAddReLUNd if is_add_fuse else snni.SpconvReLUNd
map_to_fused_module_eval = { map_to_fused_module_eval = {
k: snni.SpconvReLUNd for k in DEFAULT_SPARSE_CONV_TYPES k: fuse_cls for k in DEFAULT_SPARSE_CONV_TYPES
} }
fused_module = map_to_fused_module_eval.get(type(conv), None) fused_module = map_to_fused_module_eval.get(type(conv), None)
if fused_module is not None: if fused_module is not None:
...@@ -76,28 +79,21 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): ...@@ -76,28 +79,21 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
else: else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# }
# def get_spconv_fuse_method_mapping():
#     return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
def fuse_conv_bn_add_relu(is_qat, relu, add_pattern):
    r"""Fuse a matched ``relu(add(bn(conv), ...))`` pattern into one module.

    Thin adapter around :func:`fuse_conv_bn_relu` that calls it with
    ``is_add_fuse=True``.

    Args:
        is_qat: forwarded to ``fuse_conv_bn_relu``.
        relu: the ReLU module of the matched pattern.
        add_pattern: 3-element sequence whose middle element is a ``(bn, conv)``
            pair; the first and last elements are ignored here.
            NOTE(review): presumably ``(add_op, (bn, conv), extra_input)`` as
            produced by the FX pattern matcher - confirm against the caller.

    Returns:
        Whatever ``fuse_conv_bn_relu(is_qat, conv, bn, relu, True)`` returns.
    """
    _add_op, (bn, conv), _extra_input = add_pattern
    return fuse_conv_bn_relu(is_qat, conv, bn, relu, is_add_fuse=True)
# Default map for swapping float module to qat modules
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd, SpconvAddReLUNd
...@@ -60,3 +60,27 @@ class SpconvBnAddReLUNd(_FusedSparseModule): ...@@ -60,3 +60,27 @@ class SpconvBnAddReLUNd(_FusedSparseModule):
isinstance(relu, ReLU), 'Incorrect types for input modules{}{}{}' \ isinstance(relu, ReLU), 'Incorrect types for input modules{}{}{}' \
.format(type(conv), type(bn), type(relu)) .format(type(conv), type(bn), type(relu))
super().__init__(conv, bn, relu) super().__init__(conv, bn, relu)
def forward(self, input, add_input):
    """Run conv -> bn -> (+ add_input) -> relu on the features.

    Args:
        input: passed to the first sub-module (the conv); its result must
            expose ``features`` and ``replace_feature``.
        add_input: object whose ``features`` are added to the normalized
            conv features before the ReLU.

    Returns:
        The conv result with its features replaced by
        ``relu(bn(conv_features) + add_input.features)``.
    """
    conv, bn, relu = self[0], self[1], self[2]
    normalized = conv(input)
    normalized = normalized.replace_feature(bn(normalized.features))
    summed_features = normalized.features + add_input.features
    return normalized.replace_feature(relu(summed_features))
class SpconvAddReLUNd(_FusedSparseModule):
    r"""Fused container holding a sparse conv followed by a ReLU, with a
    residual add in between.

    ``forward`` takes a second argument whose features are added to the conv
    output before the ReLU is applied. There is no batch norm in this
    container.
    """

    def __init__(self, conv, relu):
        assert isinstance(conv, SparseConvolution) and isinstance(relu, ReLU), \
            'Incorrect types for input modules{}{}'.format(type(conv), type(relu))
        super().__init__(conv, relu)

    def forward(self, input, add_input):
        """Return the conv output with features ``relu(conv + add_input)``.

        Args:
            input: passed to the conv sub-module (index 0).
            add_input: object whose ``features`` are added to the conv
                features before the ReLU (index 1) is applied.
        """
        conv, relu = self[0], self[1]
        conv_out = conv(input)
        summed_features = conv_out.features + add_input.features
        return conv_out.replace_feature(relu(summed_features))
...@@ -12,4 +12,5 @@ ...@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .modules import SparseConvBn, SparseConvBnAddReLU, SparseConvBnReLU, SparseConv, SparseConvReLU from .modules import (SparseConv, SparseConvAddReLU, SparseConvBn,
\ No newline at end of file SparseConvBnAddReLU, SparseConvBnReLU, SparseConvReLU)
...@@ -17,7 +17,7 @@ import spconv.pytorch.quantization.intrinsic as snni ...@@ -17,7 +17,7 @@ import spconv.pytorch.quantization.intrinsic as snni
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
MOD = TypeVar('MOD', bound=SparseConvolution) MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConv(SparseConvolution, nni._FusedModule): class _SparseConv(SparseConvolution):
_FLOAT_MODULE = MOD _FLOAT_MODULE = MOD
_FLOAT_CONV_MODULE = SparseConvolution _FLOAT_CONV_MODULE = SparseConvolution
...@@ -67,7 +67,7 @@ class _SparseConv(SparseConvolution, nni._FusedModule): ...@@ -67,7 +67,7 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input): def forward(self, input):
return self._conv_forward(False, input, self.weight_fake_quant(self.weight), self.bias) return self._conv_forward(self.training, input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod @staticmethod
def from_float(cls, mod): def from_float(cls, mod):
...@@ -77,11 +77,12 @@ class _SparseConv(SparseConvolution, nni._FusedModule): ...@@ -77,11 +77,12 @@ class _SparseConv(SparseConvolution, nni._FusedModule):
`mod`: a float module, either produced by torch.ao.quantization utilities `mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user or directly from user
""" """
assert type(mod) == cls._FLOAT_MODULE, ( assert issubclass(type(mod), cls._FLOAT_MODULE), (
"qat." "qat."
+ cls.__name__ + cls.__name__
+ ".from_float only works for " + ".from_float only works for "
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined] + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
+ f" not {type(mod).__qualname__}"
) )
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig' assert mod.qconfig, 'Input float module must have a valid qconfig'
...@@ -197,6 +198,33 @@ class SparseConvReLU(SparseConv, nni._FusedModule): ...@@ -197,6 +198,33 @@ class SparseConvReLU(SparseConv, nni._FusedModule):
def from_float(cls, mod): def from_float(cls, mod):
return super(SparseConvReLU, cls).from_float(mod) return super(SparseConvReLU, cls).from_float(mod)
class SparseConvAddReLU(SparseConv, nni._FusedModule):
    r"""QAT module for a fused sparse conv + add + ReLU.

    The weight goes through ``weight_fake_quant`` before the convolution, and
    ``forward`` takes an additional ``add_input`` that is handed to
    ``_conv_forward``.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = snni.SpconvAddReLUNd
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def forward(self, input, add_input):
        """Convolve with the fake-quantized weight, then apply ReLU.

        Args:
            input: sparse tensor passed to ``_conv_forward``.
            add_input: forwarded as the ``add_input`` keyword of
                ``_conv_forward``.
                NOTE(review): presumably its features are added to the conv
                output before the ReLU - confirm in ``_conv_forward``.

        Returns:
            The ``_conv_forward`` result with ``F.relu`` applied to its
            features.
        """
        fake_quant_weight = self.weight_fake_quant(self.weight)
        conv_out = self._conv_forward(self.training,
                                      input,
                                      fake_quant_weight,
                                      self.bias,
                                      add_input=add_input)
        activated = F.relu(conv_out.features)
        return conv_out.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Delegate to the parent ``from_float`` with this class as ``cls``."""
        return super().from_float(mod)
class _SparseConvBn(SparseConvolution, nni._FusedModule): class _SparseConvBn(SparseConvolution, nni._FusedModule):
_version = 2 _version = 2
...@@ -323,9 +351,9 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -323,9 +351,9 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype) zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
conv_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias) conv_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv = conv_spt.features conv = conv_spt.features
conv_orig = conv / scale_factor.reshape(bias_shape) conv_orig = conv / scale_factor# .reshape(bias_shape)
if self.bias is not None: if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape(bias_shape) conv_orig = conv_orig + self.bias# .reshape(bias_shape)
conv = self.bn(conv_orig) conv = self.bn(conv_orig)
if add_input is not None: if add_input is not None:
conv = conv + add_input.features conv = conv + add_input.features
...@@ -377,7 +405,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -377,7 +405,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
conv_out = torch.Tensor() conv_out = torch.Tensor()
if self.bn.training: if self.bn.training:
# needed to compute batch mean/std # needed to compute batch mean/std
conv_spt = self._conv_forward(input, self.weight, zero_bias) conv_spt = self._conv_forward(self.training, input, self.weight, zero_bias)
conv_out = conv_spt.features conv_out = conv_spt.features
# update bn statistics # update bn statistics
with torch.no_grad(): with torch.no_grad():
...@@ -393,7 +421,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -393,7 +421,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
self.weight * scale_factor.reshape(weight_shape) self.weight * scale_factor.reshape(weight_shape)
) )
# fused conv without bias for inference: (r * W / running_std) * X # fused conv without bias for inference: (r * W / running_std) * X
conv_bn_spt = self._conv_forward(input, scaled_weight, zero_bias) conv_bn_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv_bn = conv_bn_spt.features conv_bn = conv_bn_spt.features
if self.bn.training: if self.bn.training:
avg_dims = [0] + list(range(2, len(self.weight.shape))) avg_dims = [0] + list(range(2, len(self.weight.shape)))
...@@ -669,12 +697,12 @@ class SparseConvBnAddReLU(_SparseConvBn): ...@@ -669,12 +697,12 @@ class SparseConvBnAddReLU(_SparseConvBn):
""" """
# base class defines _FLOAT_MODULE as "ConvBn1d" # base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE = snni.SpconvBnReLUNd # type: ignore[assignment] _FLOAT_MODULE = snni.SpconvBnAddReLUNd # type: ignore[assignment]
_FLOAT_CONV_MODULE = SparseConvolution _FLOAT_CONV_MODULE = SparseConvolution
_FLOAT_BN_MODULE = nn.BatchNorm1d _FLOAT_BN_MODULE = nn.BatchNorm1d
_FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
# module class after fusing bn into conv # module class after fusing bn into conv
_FUSED_FLOAT_MODULE = snni.SpconvReLUNd _FUSED_FLOAT_MODULE = snni.SpconvAddReLUNd
def forward(self, input, add_input): def forward(self, input, add_input):
x = _SparseConvBn._forward(self, input, add_input) x = _SparseConvBn._forward(self, input, add_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment