Commit b1c57a31 authored by yan.yan's avatar yan.yan
Browse files

still working on int8

parent aa26c99e
...@@ -23,33 +23,95 @@ from torchvision import datasets, transforms ...@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
import contextlib import contextlib
import torch.cuda.amp import torch.cuda.amp
import torch.ao.quantization
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.ao.quantization.quantize_fx as qfx
from spconv.pytorch.quantization.fake_q import get_default_spconv_qconfig_mapping
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
from torch.ao.quantization import get_default_qconfig_mapping
from spconv.pytorch.quantization.backend_cfg import SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from torch.ao.quantization.fx._lower_to_native_backend import STATIC_LOWER_FUSED_MODULE_MAP
@contextlib.contextmanager @contextlib.contextmanager
def identity_ctx(): def identity_ctx():
yield yield
class SubMConvBNReLU(spconv.SparseSequential):
    """``spconv.SubMConv2d`` -> ``nn.BatchNorm1d`` -> ``nn.ReLU`` block.

    The convolution is built without a bias term, and its padding is derived
    from the kernel size as ``(kernel_size - 1) // 2``.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also the BatchNorm width).
        kernel_size: convolution kernel size.
        stride: stride forwarded to the convolution.
        groups: ``groups`` value forwarded to the convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        # Layers are created in conv -> bn -> relu order, the same order in
        # which they are registered in the sequential container.
        conv = spconv.SubMConv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm1d(out_planes, momentum=0.1)
        act = nn.ReLU(inplace=False)
        super().__init__(conv, norm, act)
class SparseConvBNReLU(spconv.SparseSequential):
    """``spconv.SparseConv2d`` -> ``nn.BatchNorm1d`` -> ``nn.ReLU`` block.

    The convolution is built without a bias term, and its padding is derived
    from the kernel size as ``(kernel_size - 1) // 2``.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also the BatchNorm width).
        kernel_size: convolution kernel size.
        stride: stride forwarded to the convolution.
        groups: ``groups`` value forwarded to the convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        # Layers are created in conv -> bn -> relu order, the same order in
        # which they are registered in the sequential container.
        conv = spconv.SparseConv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm1d(out_planes, momentum=0.1)
        act = nn.ReLU(inplace=False)
        super().__init__(conv, norm, act)
class Net(nn.Module):
    """MNIST classifier: sparse conv feature extractor plus a dense MLP head.

    ``forward`` takes an already-built ``spconv.SparseConvTensor``, runs it
    through three conv-bn-relu blocks, densifies the result with
    ``spconv.ToDense`` and classifies it with two ``nn.Linear`` layers.
    """

    def __init__(self):
        super().__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),
            spconv.ToDense(),
        )
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        # Both dropouts run on flattened 2-D activations. Dropout2d is meant
        # for 3-D/4-D inputs and emits a deprecation warning on 2-D input;
        # plain Dropout is the documented behavior-retaining replacement.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # The stubs are registered but not invoked by forward() below.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x_sp: spconv.SparseConvTensor):
        """Return per-class log-probabilities for the samples in ``x_sp``.

        Args:
            x_sp: sparse input tensor. Presumably built from NHWC
                ``[N, 28, 28, 1]`` images via ``SparseConvTensor.from_dense``
                (TODO confirm against the caller).

        Returns:
            The result of ``F.log_softmax(..., dim=1)`` applied to the
            ``fc2`` output.
        """
        x = self.net(x_sp)
        x = torch.flatten(x, 1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
class NetV2(nn.Module):
def __init__(self):
super(NetV2, self).__init__()
self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3),
SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2),
spconv.ToDense(), spconv.ToDense(),
) )
self.fc1 = nn.Linear(14 * 14 * 64, 128) self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10) self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25) self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5) self.dropout2 = nn.Dropout2d(0.5)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x: torch.Tensor): def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor # x: [N, 28, 28, 1], must be NHWC tensor
x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1)) x = self.quant(features)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense # create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp) x = self.net(x_sp)
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
...@@ -58,10 +120,93 @@ class Net(nn.Module): ...@@ -58,10 +120,93 @@ class Net(nn.Module):
x = F.relu(x) x = F.relu(x)
x = self.dropout2(x) x = self.dropout2(x)
x = self.fc2(x) x = self.fc2(x)
x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
class NetPTQ(nn.Module):
    """MNIST classifier made only of sparse conv layers, used for PTQ.

    PyTorch currently doesn't support CUDA int8 inference, so only sparse ops
    are used here: the class scores come from a final ``SparseConv2d`` with
    10 output channels instead of ``nn.Linear`` layers.
    """

    def __init__(self):
        super().__init__()
        stages = [
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            # NOTE: the trailing positional 1 binds to ``groups``; the padding
            # is computed inside the block from kernel_size.
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*stages)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Return per-class log-probabilities.

        Args:
            features: feature tensor of the sparse input; passed through the
                quant stub before the sparse tensor is built.
            indices: index tensor of the sparse input.
            batch_size: batch size handed to ``SparseConvTensor``.

        Returns:
            ``F.log_softmax`` over dim 1 of the flattened, dequantized output.
        """
        q_features = self.quant(features)
        # The spatial shape is hard-coded to the 28x28 MNIST grid.
        sparse_in = spconv.SparseConvTensor(q_features, indices, [28, 28], batch_size)
        # self.net ends with ToDense, so what comes back is no longer sparse.
        dense_out = self.net(sparse_in)
        logits = self.dequant(torch.flatten(dense_out, 1))
        return F.log_softmax(logits, dim=1)
class NetDense(nn.Module):
    """Dense ``nn.Conv2d`` MNIST classifier wrapped in quant/dequant stubs.

    Two 3x3 convolutions, a 2x2 max-pool and two linear layers, with dropout
    before each linear layer. ``forward`` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it fixes both the parameter
        # initialization order and the order reported by named_children().
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return ``log_softmax`` class scores for a batch of images ``x``."""
        x = self.quant(x)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        logits = self.dequant(self.fc2(x))
        return F.log_softmax(logits, dim=1)
def train(args, model, device, train_loader, optimizer, epoch): def train(args, model, device, train_loader, optimizer, epoch):
model.train() model.train()
scaler = torch.cuda.amp.grad_scaler.GradScaler() scaler = torch.cuda.amp.grad_scaler.GradScaler()
...@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch): ...@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
optimizer.zero_grad() optimizer.zero_grad()
with amp_ctx: with amp_ctx:
if args.sparse:
data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
# output = model(data_sp)
output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
else:
output = model(data) output = model(data)
loss = F.nll_loss(output, target) loss = F.nll_loss(output, target)
scale = 1.0 scale = 1.0
if args.fp16: if args.fp16:
...@@ -114,7 +265,11 @@ def test(args, model, device, test_loader): ...@@ -114,7 +265,11 @@ def test(args, model, device, test_loader):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with amp_ctx: with amp_ctx:
if args.sparse:
data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
# output = model(data_sp)
output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
else:
output = model(data) output = model(data)
test_loss += F.nll_loss( test_loss += F.nll_loss(
output, target, reduction='sum').item() # sum up batch loss output, target, reduction='sum').item() # sum up batch loss
...@@ -131,6 +286,19 @@ def test(args, model, device, test_loader): ...@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
100. * correct / len(test_loader.dataset))) 100. * correct / len(test_loader.dataset)))
def calibrate(args, model: torch.nn.Module, data_loader, device):
    """Run ``model`` over ``data_loader`` so its observers can collect stats.

    The model is switched to eval mode and every batch is forwarded with
    gradients disabled. Model outputs are discarded; the function is called
    only for the side effects of the forward passes.

    Args:
        args: namespace with a boolean ``sparse`` attribute. When true, each
            batch is converted to a ``SparseConvTensor`` and the model is
            called with ``(features, indices, batch_size)``; otherwise the
            model is called with the image batch directly.
        model: module to run (put into eval mode in place).
        data_loader: iterable yielding ``(image, target)`` pairs; the target
            is ignored.
        device: device each image batch is moved to before the forward pass.

    Returns:
        None.
    """
    model.eval()
    with torch.no_grad():
        for image, _ in data_loader:
            image = image.to(device)
            if args.sparse:
                # from_dense needs a channel-last tensor; 28x28x1 is
                # hard-coded for MNIST here.
                data_sp = spconv.SparseConvTensor.from_dense(image.reshape(-1, 28, 28, 1))
                model(data_sp.features, data_sp.indices, data_sp.batch_size)
            else:
                model(image)
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
...@@ -146,7 +314,7 @@ def main(): ...@@ -146,7 +314,7 @@ def main():
help='input batch size for testing (default: 1000)') help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', parser.add_argument('--epochs',
type=int, type=int,
default=14, default=1,
metavar='N', metavar='N',
help='number of epochs to train (default: 14)') help='number of epochs to train (default: 14)')
parser.add_argument('--lr', parser.add_argument('--lr',
...@@ -168,6 +336,10 @@ def main(): ...@@ -168,6 +336,10 @@ def main():
default=1, default=1,
metavar='S', metavar='S',
help='random seed (default: 1)') help='random seed (default: 1)')
parser.add_argument('--sparse',
action='store_true',
default=True,
help='use sparse conv network instead of dense')
parser.add_argument( parser.add_argument(
'--log-interval', '--log-interval',
type=int, type=int,
...@@ -190,8 +362,14 @@ def main(): ...@@ -190,8 +362,14 @@ def main():
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu") device = torch.device("cuda" if use_cuda else "cpu")
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse:
model = NetPTQ().to(device)
else:
model = NetDense().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST( datasets.MNIST(
'../data', '../data',
...@@ -218,17 +396,46 @@ def main(): ...@@ -218,17 +396,46 @@ def main():
shuffle=True, shuffle=True,
**kwargs) **kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader) test(args, model, device, test_loader)
scheduler.step() scheduler.step()
# if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt")
model.eval()
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
if not args.sparse:
model = model.cpu()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
qconfig_mapping = get_default_spconv_qconfig_mapping(False)
prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config()
convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse the model (patterns such as conv-bn-relu are fused into modules from torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic),
# then add observers to the fused model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# prepared_model.print_readable()
print([type(m) for m in prepared_model.modules()])
print(prepared_model)
# calibrate: run model with some inputs
calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_to_reference_fx(prepared_model, convert_cfg, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
print([type(m) for m in converted_model.modules()])
# TensorRT only supports symmetric quantization, with per-tensor activations and per-channel weights.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print(converted_model)
if args.save_model: test(args, converted_model, qdevice, test_loader)
torch.save(model.state_dict(), "mnist_cnn.pt") breakpoint()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,7 +21,7 @@ from ccimport.compat import InWindows ...@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
if project_is_installed(PACKAGE_NAME) and project_is_editable( if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT: PACKAGE_NAME) and not DISABLE_JIT and False:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
......
...@@ -699,6 +699,22 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -699,6 +699,22 @@ IMPLGEMM_AMPERE_PARAMS = [
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
...@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
......
...@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty ...@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor from cumm.tensorview import Tensor
class ExternalAllocator: class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor: def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> Tensor:
""" """
Args: Args:
name: name:
...@@ -11,9 +11,10 @@ class ExternalAllocator: ...@@ -11,9 +11,10 @@ class ExternalAllocator:
device: device:
stream: stream:
is_temp_memory: is_temp_memory:
scale:
""" """
... ...
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor: def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> Tensor:
""" """
Args: Args:
name: name:
...@@ -22,6 +23,7 @@ class ExternalAllocator: ...@@ -22,6 +23,7 @@ class ExternalAllocator:
device: device:
stream: stream:
is_temp_memory: is_temp_memory:
scale:
""" """
... ...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor: def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
......
...@@ -63,7 +63,7 @@ class ConvGemmOps: ...@@ -63,7 +63,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0, output_dtype: int = -1) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
...@@ -91,6 +91,7 @@ class ConvGemmOps: ...@@ -91,6 +91,7 @@ class ConvGemmOps:
scale: scale:
output_add: output_add:
output_add_scale: output_add_scale:
output_dtype:
""" """
... ...
@staticmethod @staticmethod
......
...@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class): ...@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
code.arg("device", "int") code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false") code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True) @pccm.pybind.mark(virtual=True)
...@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class): ...@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
code.arg("device", "int") code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0") code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false") code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True) @pccm.pybind.mark(virtual=True)
......
...@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
}} }}
if (is_subm){{ if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int, false /*is_temp*/, output_scale);
}}else{{ }}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)}, out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int); {{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int, false /*is_temp*/, output_scale);
}} }}
// auto start_ev = tv::CUDAEvent(); // auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int); // start_ev.record(stream_int);
......
...@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class): ...@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
code.raw(f""" code.raw(f"""
auto nhot = out_inds.dim(0); auto nhot = out_inds.dim(0);
auto cudastream = reinterpret_cast<cudaStream_t>(stream); auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1)); auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims); int num_blocks_X = std::get<0>(launchdims);
...@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class): ...@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}}); tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream); auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1)); auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims); int num_blocks_X = std::get<0>(launchdims);
...@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class): ...@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}}); tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream); auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{ tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I); using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1)); auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims); int num_blocks_X = std::get<0>(launchdims);
......
This diff is collapsed.
...@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape): ...@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
return ret return ret
# ProxyableClassMeta is used for TensorRT conversion in future. # ProxyableClassMeta is used for torch.fx
class SparseConvTensor(metaclass=SpConvTensorMeta): class SparseConvTensor(metaclass=SpConvTensorMeta):
def __init__(self, def __init__(self,
features: torch.Tensor, features: torch.Tensor,
...@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer) self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo self.force_algo = force_algo
# for simple int8 torch inference
self.int8_scale: Optional[float] = None @property
def is_quantized(self):
return self.features.dtype == torch.qint8
def q_scale(self):
if self.is_quantized:
return self.features.q_scale()
raise ValueError("sparse tensor must be quantized")
def replace_feature(self, feature: torch.Tensor): def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
x must be NHWC tensor, channel last x must be NHWC tensor, channel last
""" """
x_sp = x.to_sparse(x.ndim - 1) x_sp = x.to_sparse(x.ndim - 1)
spatial_shape = list(x_sp.shape[1:-1]) spatial_shape = x_sp.shape[1:-1]
batch_size = x_sp.shape[0] batch_size = x_sp.shape[0]
indices_th = x_sp.indices().permute(1, 0).contiguous().int() indices_th = x_sp.indices().permute(1, 0).contiguous().int()
features_th = x_sp.values() features_th = x_sp.values()
......
...@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = { ...@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
torch.int8: tv.int8, torch.int8: tv.int8,
torch.int16: tv.int16, torch.int16: tv.int16,
torch.uint8: tv.uint8, torch.uint8: tv.uint8,
torch.qint8: tv.int8,
} }
_TORCH_UINT_WORKAROUNDS = { _TORCH_UINT_WORKAROUNDS = {
...@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = { ...@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
tv.uint64: tv.int64 tv.uint64: tv.int64
} }
_TH_QTYPES = {torch.qint8}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()} _TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({ _TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32, tv.uint32: torch.int32,
...@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({ ...@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
}) })
_TV_DTYPE_TO_TORCHQ = _TV_DTYPE_TO_TORCH.copy()
_TV_DTYPE_TO_TORCHQ[tv.int8] = torch.qint8
_ALL_INTS = { _ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32, tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16 tv.uint16
...@@ -105,23 +111,31 @@ def get_arch(): ...@@ -105,23 +111,31 @@ def get_arch():
class TorchAllocator(ExternalAllocator): class TorchAllocator(ExternalAllocator):
def __init__(self, gpudevice: torch.device) -> None: def __init__(self, gpudevice: torch.device, is_quantized: bool = False) -> None:
super().__init__() super().__init__()
self.gpudevice = gpudevice self.gpudevice = gpudevice
self.cpudevice = torch.device("cpu") self.cpudevice = torch.device("cpu")
self.allocated: Dict[Union[str, int], torch.Tensor] = {} self.allocated: Dict[Union[str, int], torch.Tensor] = {}
self.is_quantized = is_quantized
self._tv_dtype_to_torch = _TV_DTYPE_TO_TORCH
if is_quantized:
self._tv_dtype_to_torch = _TV_DTYPE_TO_TORCHQ
def zeros(self, name: str, shape: List[int], dtype: int, def zeros(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> tv.Tensor:
# TODO free memory by name if its already free by pointer. # TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit. # provide a name if you want to access it after c++ function exit.
dtype_bkp = dtype dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev) if self.is_quantized:
ten = torch._empty_affine_quantized(shape, scale=scale, zero_point=0, dtype=th_dtype, device=dev)
else:
ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_()
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
...@@ -129,13 +143,16 @@ class TorchAllocator(ExternalAllocator): ...@@ -129,13 +143,16 @@ class TorchAllocator(ExternalAllocator):
return ten_tv return ten_tv
def empty(self, name: str, shape: List[int], dtype: int, def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor: device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> tv.Tensor:
dtype_bkp = dtype dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
if self.is_quantized:
ten = torch._empty_affine_quantized(shape, scale=scale, zero_point=0, dtype=th_dtype, device=dev)
else:
ten = torch.empty(shape, dtype=th_dtype, device=dev) ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
...@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator): ...@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
dtype_bkp = dtype dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
if self.is_quantized:
assert th_dtype not in _TH_QTYPES
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
...@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator): ...@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0: if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes") raise NotImplementedError("you can't use full for unsigned dtypes")
dtype_bkp = dtype dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype] th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1: if device == -1:
dev = self.cpudevice dev = self.cpudevice
else: else:
dev = self.gpudevice dev = self.gpudevice
if self.is_quantized:
assert th_dtype not in _TH_QTYPES
ten = torch.full(shape, value, dtype=th_dtype, device=dev) ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
......
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backend_cfg import (get_spconv_backend_config,
get_spconv_prepare_custom_config,
get_spconv_convert_custom_config)
from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig)
from .qmapping import (get_spconv_fmod_to_qat_mapping,
get_spconv_qat_to_static_mapping)
from collections import namedtuple
from typing import List, Dict, Union, Type, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.quantization.backend_config import (BackendConfig,
BackendPatternConfig,
DTypeConfig, ObservationType,
get_tensorrt_backend_config)
from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig,
FuseCustomConfig,
PrepareCustomConfig)
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
import spconv.pytorch.conv as sconvmod
import spconv.pytorch.quantization.intrinsic as snni
import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic.quantized as snniq
import spconv.pytorch.quantization.quantized as snnq
import spconv.pytorch.quantization.quantized.reference as snnqr
from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn,
fuse_conv_bn_relu)
from spconv.pytorch import ToDense
# Bundle of the module classes associated with one sparse conv "root" class:
# the BatchNorm class it fuses with, its reference quantized module, the float
# fused containers, and the QAT counterparts of each of those.
_SpConvMetadataDef = namedtuple(
    "_ConvMetadata",
    [
        "root", "bn", "reference",
        "fused_conv_relu", "fused_conv_bn", "fused_conv_bn_relu",
        "qat", "relu_qat", "bn_qat", "bn_relu_qat",
    ])

_SpConvMetadatas: List[_SpConvMetadataDef] = []
# One entry per default sparse conv class; every entry shares the same
# companion classes and only differs in ``root``.
for t in sconvmod.DEFAULT_SPARSE_CONV_TYPES:
    _SpConvMetadatas.append(_SpConvMetadataDef(
        root=t,
        bn=nn.BatchNorm1d,
        reference=snnqr.SpConv,
        fused_conv_relu=snni.SpconvReLUNd,
        fused_conv_bn=snni.SpconvBnNd,
        fused_conv_bn_relu=snni.SpconvBnReLUNd,
        qat=snniqat.SparseConv,
        relu_qat=snniqat.SparseConvReLU,
        bn_qat=snniqat.SparseConvBn,
        bn_relu_qat=snniqat.SparseConvBnReLU,
    ))
# The generic SparseConvolution class is registered as well.
_SpConvMetadatas.append(_SpConvMetadataDef(
    root=sconvmod.SparseConvolution,
    bn=nn.BatchNorm1d,
    reference=snnqr.SpConv,
    fused_conv_relu=snni.SpconvReLUNd,
    fused_conv_bn=snni.SpconvBnNd,
    fused_conv_bn_relu=snni.SpconvBnReLUNd,
    qat=snniqat.SparseConv,
    relu_qat=snniqat.SparseConvReLU,
    bn_qat=snniqat.SparseConvBn,
    bn_relu_qat=snniqat.SparseConvBnReLU,
))
def _sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def fuser_method(is_qat, m1, m2):
return sequential(m1, m2)
return fuser_method
# Backend configs for PyTorch newer than 1.13.1 no longer use the reversed pattern format.
def _nest_reversed_pattern(ops):
    """Turn a forward-order op sequence into the reversed nested pattern.

    ``(a, b)`` becomes ``(b, a)`` and ``(a, b, c)`` becomes ``(c, (b, a))``.
    """
    pattern = ops[0]
    for op in ops[1:]:
        pattern = (op, pattern)
    return pattern


def _get_spconv_configs(dtype_configs):
    """Return all backend pattern configs related to sparse conv modules.

    For every entry of ``_SpConvMetadatas`` this registers configs for the
    plain conv module, its QAT module, the conv+relu / conv+bn / conv+bn+relu
    fusion patterns, and the fused float and QAT modules.

    When ``PYTORCH_VERSION <= [1, 13, 1]`` the fusion patterns are written in
    the reversed nested form and the fuser methods are wrapped with the
    ``reverse*`` helpers from ``torch.ao.quantization.fuser_method_mappings``;
    otherwise forward-order tuples and the unwrapped fuser methods are used.
    The set and order of registered configs is the same in both cases.

    Args:
        dtype_configs: list of ``DTypeConfig`` attached to every pattern.

    Returns:
        List of ``BackendPatternConfig``.
    """
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT

    if PYTORCH_VERSION <= [1, 13, 1]:
        # Imported lazily: these helpers are only needed on the old versions.
        from torch.ao.quantization.fuser_method_mappings import (
            reverse2, reverse3, reverse_sequential_wrapper2)
        make_pattern = _nest_reversed_pattern
        wrap_sequential = reverse_sequential_wrapper2
        conv_bn_fuser = reverse2(fuse_conv_bn)
        conv_bn_relu_fuser = reverse3(fuse_conv_bn_relu)
    else:
        make_pattern = tuple
        wrap_sequential = _sequential_wrapper2
        conv_bn_fuser = fuse_conv_bn
        conv_bn_relu_fuser = fuse_conv_bn_relu

    conv_configs = []
    for convs in _SpConvMetadatas:
        conv_relu_fuser = wrap_sequential(convs.fused_conv_relu)

        # (1) Single conv modules
        # -----------------------------------
        # conv module
        conv_configs.append(
            BackendPatternConfig(convs.root)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.qat))
        # conv qat module
        conv_configs.append(
            BackendPatternConfig(convs.qat)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference))

        # (2) Conv + relu
        # -----------------
        # 2.1 conv module + relu fusion configs
        # conv relu fusion, conv module + relu module
        conv_configs.append(
            BackendPatternConfig(make_pattern((convs.root, torch.nn.ReLU)))
            .set_dtype_configs(dtype_configs)
            .set_fuser_method(conv_relu_fuser)
            .set_fused_module(convs.fused_conv_relu))
        # conv relu fusion, conv module + functional relu
        conv_configs.append(
            BackendPatternConfig(make_pattern((convs.root, F.relu)))
            .set_dtype_configs(dtype_configs)
            .set_fuser_method(conv_relu_fuser)
            .set_fused_module(convs.fused_conv_relu))
        # 2.2 conv module + relu fused module configs
        # conv relu, fused module
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.relu_qat))
        # conv relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference))
        # 2.3 second registration of the fused conv relu patterns.
        # NOTE(review): these two entries repeat the 2.2 patterns with fewer
        # fields set. They are kept so the returned list is unchanged; if the
        # consumer keys configs by pattern, these later entries presumably
        # replace the 2.2 ones - confirm this is intended.
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_dtype_configs(dtype_configs)
            .set_qat_module(convs.relu_qat))
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference))

        # (3) Conv + batchnorm (+ relu)
        # -------------------------------
        # 3.1 conv bn fusion configs
        # conv + bn fusion
        conv_configs.append(
            BackendPatternConfig(make_pattern((convs.root, convs.bn)))
            .set_dtype_configs(dtype_configs)
            .set_fuser_method(conv_bn_fuser)
            .set_fused_module(convs.fused_conv_bn))
        # conv + bn + relu module fusion
        conv_configs.append(
            BackendPatternConfig(make_pattern((convs.root, convs.bn, nn.ReLU)))
            .set_dtype_configs(dtype_configs)
            .set_fuser_method(conv_bn_relu_fuser)
            .set_fused_module(convs.fused_conv_bn_relu))
        # conv + bn + relu functional fusion
        conv_configs.append(
            BackendPatternConfig(make_pattern((convs.root, convs.bn, F.relu)))
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_fuser_method(conv_bn_relu_fuser)
            .set_fused_module(convs.fused_conv_bn_relu))
        # TODO: we can add fusion for torch.relu as well

        # 3.2 conv + bn (+ relu) fused module configs
        # fused conv bn
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn)
            .set_dtype_configs(dtype_configs)
            .set_qat_module(convs.bn_qat))
        # fused conv bn relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn_relu)
            .set_dtype_configs(dtype_configs)
            .set_qat_module(convs.bn_relu_qat))
        # conv bn, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_qat)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference))
        # conv bn relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_relu_qat)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference))

        # (4) conv transpose (+ bn) patterns are not registered.
    return conv_configs
# Dtype constraint attached to every sparse conv pattern: qint8 input, output
# and weight, with a float bias.
weighted_op_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.qint8,
    output_dtype=torch.qint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)
conv_dtype_configs = [weighted_op_qint8_dtype_config]

# ToDense is configured so its output shares the observer of its input.
_to_dense_cfg = BackendPatternConfig(ToDense).set_observation_type(
    ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)

# Start from the TensorRT backend config and register the sparse conv
# patterns plus the ToDense pattern on it.
backend_config = get_tensorrt_backend_config().set_backend_pattern_configs(
    _get_spconv_configs(conv_dtype_configs) + [_to_dense_cfg])

# Fused float module -> (reference module, quantized fused module).
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
    snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
}
def get_spconv_backend_config():
    """Return the module-level spconv ``backend_config`` object.

    The same object is returned on every call; it is not copied.
    """
    return backend_config
def get_spconv_prepare_custom_config():
    """Return a new ``PrepareCustomConfig`` for spconv models.

    Its ``non_traceable_module_classes`` is set to a fresh list of the classes
    in ``sconvmod.DEFAULT_SPARSE_CONV_TYPES``.
    """
    prepare_cfg = PrepareCustomConfig()
    prepare_cfg.non_traceable_module_classes = list(
        sconvmod.DEFAULT_SPARSE_CONV_TYPES)
    return prepare_cfg
def get_spconv_convert_custom_config():
    """Return a new ``ConvertCustomConfig`` for spconv models.

    Registers a single observed-to-quantized mapping:
    ``snni.SpconvReLUNd`` -> ``snniq.SparseConvReLU``.
    """
    convert_cfg = ConvertCustomConfig()
    convert_cfg.set_observed_to_quantized_mapping(
        snni.SpconvReLUNd, snniq.SparseConvReLU)
    # TODO: mappings for the other fused modules are not registered yet.
    return convert_cfg
\ No newline at end of file
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize, fused_wt_fake_quant_range_neg_127_to_127
from spconv.pytorch.core import SparseConvTensor
import torch import torch
from torch.ao.quantization.qconfig import QConfig from torch.ao.quantization.fake_quantize import (
from torch.ao.quantization.observer import MovingAverageMinMaxObserver FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, FakeQuantize,
default_fused_per_channel_wt_fake_quant, default_weight_fake_quant, default_per_channel_weight_fake_quant)
from torch.ao.quantization.observer import (HistogramObserver,
MovingAverageMinMaxObserver,
default_weight_observer,
default_placeholder_observer,
default_per_channel_weight_observer)
from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER
from typing import Any, Callable, Dict, Tuple, Union, List
from torch.ao.quantization import get_default_qconfig
from spconv.pytorch.core import SparseConvTensor
__all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"]
class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
    """``FusedMovingAvgObsFakeQuantize`` that also accepts ``SparseConvTensor``.

    Dense tensors are handled by the parent class unchanged. For a
    ``SparseConvTensor`` the parent forward runs on ``input.features`` and the
    result is put back with ``input.replace_feature``.
    """

    def forward(self, input: Union[SparseConvTensor, torch.Tensor]):
        """Observe and fake-quantize ``input``.

        Args:
            input: a ``SparseConvTensor`` or a regular ``torch.Tensor``.

        Returns:
            For a ``SparseConvTensor``, the result of
            ``input.replace_feature`` with the processed features; otherwise
            whatever the parent ``forward`` returns.
        """
        if not isinstance(input, SparseConvTensor):
            return super().forward(input)
        # spconv support: only the feature matrix goes through the parent.
        res_features = super().forward(input.features)
        return input.replace_feature(res_features)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor):
# # add lines to support spconv
# x = input.features
# res_features = super().forward(x)
# return input.replace_feature(res_features)
# else:
# return super().forward(input)
class SparseHistogramObserver(HistogramObserver):
    """``HistogramObserver`` that also accepts ``SparseConvTensor`` inputs.

    Dense tensors are handled by the parent class unchanged. For a
    ``SparseConvTensor`` the parent forward runs on ``input.features`` and the
    result is put back with ``input.replace_feature``.
    """

    def forward(self, input: Union[SparseConvTensor, torch.Tensor]):
        """Observe ``input``.

        Args:
            input: a ``SparseConvTensor`` or a regular ``torch.Tensor``.

        Returns:
            For a ``SparseConvTensor``, the result of
            ``input.replace_feature`` with the value returned by the parent
            ``forward`` on the features; otherwise whatever the parent
            ``forward`` returns.
        """
        if not isinstance(input, SparseConvTensor):
            return super().forward(input)
        # spconv support: only the feature matrix goes through the parent.
        res_features = super().forward(input.features)
        return input.replace_feature(res_features)
# Post-training-quantization qconfig: symmetric per-tensor qint8 activations
# in [-128, 127] observed with SparseHistogramObserver, and the default
# per-channel weight observer for weights.
default_symmetric_spconv_ptq_qconfig = QConfig(
    activation=SparseHistogramObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        quant_min=-128,
        quant_max=127,
        reduce_range=False,
        eps=2 ** -12,
    ),
    weight=default_per_channel_weight_observer,
)
# default_symmetric_ptq_qconfig = QConfig(
# activation=HistogramObserver.with_args(quant_min=-128,
# quant_max=127,
# dtype=torch.qint8,
# reduce_range=False,
# eps=2 ** -12),
# weight=default_per_channel_weight_observer)
default_symmetric_spconv_qat_qconfig = QConfig( default_symmetric_spconv_qat_qconfig = QConfig(
activation=SparseFusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, activation=SparseFusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
...@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig( ...@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
quant_max=127, quant_max=127,
dtype=torch.qint8, dtype=torch.qint8,
reduce_range=False, reduce_range=False,
qscheme=torch.per_tensor_symmetric,
eps=2 ** -12), eps=2 ** -12),
weight=fused_wt_fake_quant_range_neg_127_to_127) weight=default_fused_per_channel_wt_fake_quant)
def get_default_spconv_trt_ptq_qconfig(backend, version):
    """Return the default spconv post-training-quantization ``QConfig``.

    Args:
        backend: currently unused.
        version: currently unused.

    Returns:
        The module-level ``default_symmetric_spconv_ptq_qconfig``.
    """
    del backend, version  # accepted for signature compatibility only
    return default_symmetric_spconv_ptq_qconfig
def get_default_spconv_trt_qat_qconfig(backend, version):
    """Return the default spconv quantization-aware-training ``QConfig``.

    Args:
        backend: currently unused.
        version: currently unused.

    Returns:
        The module-level ``default_symmetric_spconv_qat_qconfig``.
    """
    del backend, version  # accepted for signature compatibility only
    return default_symmetric_spconv_qat_qconfig
def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", version: int = 0) -> QConfigMapping:
    """Return the default spconv ``QConfigMapping`` for QAT or PTQ.

    Adapted from ``torch.ao.quantization.qconfig_mapping``; the global qconfig
    comes from the spconv TRT QAT/PTQ getters instead of the torch default.

    Args:
        is_qat: select the QAT qconfig and fake-quant weights when true,
            otherwise the PTQ qconfig and weight observers.
        backend: backend name; ``"fbgemm"``/``"x86"`` and ``"onednn"`` get the
            special handling described in the comments below.
        version: forwarded to the qconfig getters.

    Returns:
        The populated ``QConfigMapping``.
    """
    if is_qat:
        qconfig = get_default_spconv_trt_qat_qconfig(backend, version)
        default_weight = default_weight_fake_quant
    else:
        qconfig = get_default_spconv_trt_ptq_qconfig(backend, version)
        default_weight = default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with the
    # fbgemm backend, so conv transpose falls back to a per-tensor weight
    # qconfig there. See https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
    else:
        qconfig_transpose = qconfig

    # Layernorm currently only supports float weights; without this there
    # would be an extra quantize-dequantize pair.
    qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)

    conv_transpose_ops = (
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
        torch.nn.functional.conv_transpose1d,
        torch.nn.functional.conv_transpose2d,
        torch.nn.functional.conv_transpose3d,
    )
    layernorm_ops = (torch.nn.functional.layer_norm, torch.nn.LayerNorm)

    qconfig_mapping = QConfigMapping().set_global(qconfig)
    qconfig_mapping.set_object_type("reshape", default_reuse_input_qconfig)
    for transpose_op in conv_transpose_ops:
        qconfig_mapping.set_object_type(transpose_op, qconfig_transpose)
    for layernorm_op in layernorm_ops:
        qconfig_mapping.set_object_type(layernorm_op, qconfig_layernorm)

    # Use special observers for ops with fixed qparams; ops that share an
    # observer also share one QConfig instance.
    qconfig_by_observer: Dict[Any, QConfigAny] = {}
    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
        if observer not in qconfig_by_observer:
            if is_qat:
                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
            else:
                activation = observer
            qconfig_by_observer[observer] = QConfig(activation=activation, weight=default_weight)
        qconfig_mapping.set_object_type(fixed_qparams_op, qconfig_by_observer[observer])

    # QConfig for fused ops for the onednn backend: separate ops are required
    # to have the same qconfig as fused ops.
    # TODO: we should be able to configure qconfig for patterns
    if backend == 'onednn':
        onednn_ops = (
            torch.nn.Linear,
            torch.nn.LeakyReLU,
            torch.nn.functional.leaky_relu,
            torch.nn.Tanh,
            torch.nn.functional.tanh,
        )
        for onednn_op in onednn_ops:
            qconfig_mapping.set_object_type(onednn_op, qconfig)

    return qconfig_mapping
...@@ -3,9 +3,9 @@ import torch.nn as nn ...@@ -3,9 +3,9 @@ import torch.nn as nn
import spconv.pytorch as spconv import spconv.pytorch as spconv
from .utils import fuse_spconv_bn_eval from .utils import fuse_spconv_bn_eval
from . import intrinsic as snni from . import intrinsic as snni
from .conv_fused import SparseConvBn, SparseConvBnReLU from .intrinsic.qat.modules import SparseConvBn, SparseConvBnReLU, SparseConvBnAddReLU
from spconv.pytorch.conv import DEFAULT_SPARSE_CONV_TYPES
def fuse_conv_bn(conv, bn): def fuse_conv_bn(is_qat, conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module r"""Given the conv and bn modules, fuses them and returns the fused module
Args: Args:
...@@ -22,18 +22,10 @@ def fuse_conv_bn(conv, bn): ...@@ -22,18 +22,10 @@ def fuse_conv_bn(conv, bn):
"Conv and BN both must be in the same mode (train or eval)." "Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map = { fused_module_class_map = {
spconv.SubMConv1d: snni.SpconvBnNd, k: snni.SpconvBnNd for k in DEFAULT_SPARSE_CONV_TYPES
spconv.SparseConv1d: snni.SpconvBnNd,
spconv.SparseInverseConv1d: snni.SpconvBnNd,
spconv.SubMConv2d: snni.SpconvBnNd,
spconv.SparseConv2d: snni.SpconvBnNd,
spconv.SparseInverseConv2d: snni.SpconvBnNd,
spconv.SubMConv3d: snni.SpconvBnNd,
spconv.SparseConv3d: snni.SpconvBnNd,
spconv.SparseInverseConv3d: snni.SpconvBnNd,
} }
if conv.training: if is_qat:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
...@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn): ...@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
else: else:
return fuse_spconv_bn_eval(conv, bn) return fuse_spconv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu): def fuse_conv_bn_relu(is_qat, conv, bn, relu):
r"""Given the conv and bn modules, fuses them and returns the fused module r"""Given the conv and bn modules, fuses them and returns the fused module
Args: Args:
...@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu): ...@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
assert(conv.training == bn.training == relu.training),\ assert(conv.training == bn.training == relu.training),\
"Conv and BN both must be in the same mode (train or eval)." "Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[spconv.SparseSequential]] = None fused_module : Optional[Type[spconv.SparseSequential]] = None
if conv.training: if is_qat:
map_to_fused_module_train = { map_to_fused_module_train = {
spconv.SubMConv1d: snni.SpconvBnReLUNd, k: snni.SpconvBnReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
spconv.SparseConv1d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv1d: snni.SpconvBnReLUNd,
spconv.SubMConv2d: snni.SpconvBnReLUNd,
spconv.SparseConv2d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv2d: snni.SpconvBnReLUNd,
spconv.SubMConv3d: snni.SpconvBnReLUNd,
spconv.SparseConv3d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv3d: snni.SpconvBnReLUNd,
} }
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True' assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
...@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu): ...@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu))) raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else: else:
map_to_fused_module_eval = { map_to_fused_module_eval = {
spconv.SubMConv1d: snni.SpconvReLUNd, k: snni.SpconvReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
spconv.SparseConv1d: snni.SpconvReLUNd,
spconv.SparseInverseConv1d: snni.SpconvReLUNd,
spconv.SubMConv2d: snni.SpconvReLUNd,
spconv.SparseConv2d: snni.SpconvReLUNd,
spconv.SparseInverseConv2d: snni.SpconvReLUNd,
spconv.SubMConv3d: snni.SpconvReLUNd,
spconv.SparseConv3d: snni.SpconvReLUNd,
spconv.SparseInverseConv3d: snni.SpconvReLUNd,
} }
fused_module = map_to_fused_module_eval.get(type(conv), None) fused_module = map_to_fused_module_eval.get(type(conv), None)
if fused_module is not None: if fused_module is not None:
...@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu): ...@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
else: else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { # DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn, # (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, # (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
} # }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
# Default map for swapping float module to qat modules # Default map for swapping float module to qat modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni.SpconvBnNd: SparseConvBn,
snni.SpconvBnReLUNd: SparseConvBnReLU,
}
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd
...@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations ...@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
import torch.ao.nn.intrinsic as nni import torch.ao.nn.intrinsic as nni
from spconv.pytorch.conv import SparseConvolution from spconv.pytorch.conv import SparseConvolution
from spconv.pytorch.modules import is_spconv_module
from spconv.pytorch.core import SparseConvTensor
class _FusedSparseModule(nni._FusedModule):
    """Fused container whose forward understands ``SparseConvTensor``.

    Child modules run in registration order:

    * spconv modules are called with the current value as-is;
    * any other module is called on ``input.features`` when the current value
      is a ``SparseConvTensor`` with at least one index row, and the result is
      written back with ``replace_feature``; a ``SparseConvTensor`` with no
      index rows is passed on unchanged;
    * when the current value is not a ``SparseConvTensor`` the module is
      called on it directly.
    """

    def forward(self, input):
        for module in self._modules.values():
            if is_spconv_module(module) or not isinstance(input, SparseConvTensor):
                input = module(input)
            elif input.indices.shape[0] != 0:
                # Dense module on a non-empty sparse tensor: features only.
                input = input.replace_feature(module(input.features))
        return input
class SpconvReLUNd(nni._FusedModule): class SpconvReLUNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv3d and ReLU modules. r"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module.""" During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu): def __init__(self, conv, relu):
...@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule): ...@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
type(conv), type(relu)) type(conv), type(relu))
super().__init__(conv, relu) super().__init__(conv, relu)
class SpconvBnNd(nni._FusedModule): class SpconvBnNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module.""" During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn): def __init__(self, conv, bn):
...@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule): ...@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
type(conv), type(bn)) type(conv), type(bn))
super().__init__(conv, bn) super().__init__(conv, bn)
class SpconvBnReLUNd(_FusedSparseModule):
    r"""This is a sequential container which calls the sparse conv, BatchNorm1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        types_ok = (
            isinstance(conv, SparseConvolution)
            and isinstance(bn, BatchNorm1d)
            and isinstance(relu, ReLU)
        )
        assert types_ok, 'Incorrect types for input modules{}{}{}' \
            .format(type(conv), type(bn), type(relu))
        super().__init__(conv, bn, relu)
class SpconvBnReLUNd(nni._FusedModule): class SpconvBnAddReLUNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module.""" During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu): def __init__(self, conv, bn, relu):
......
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SparseConvBn, SparseConvBnAddReLU, SparseConvBnReLU, SparseConv, SparseConvReLU
\ No newline at end of file
...@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni ...@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat import torch.ao.nn.qat as nnqat
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import init from torch.nn import init
from torch.nn.utils import fuse_conv_bn_weights
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from typing import TypeVar from typing import TypeVar
from spconv.pytorch.conv import SparseConvolution from spconv.pytorch.conv import SparseConvolution
...@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo ...@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
import spconv.pytorch.quantization.intrinsic as snni import spconv.pytorch.quantization.intrinsic as snni
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
MOD = TypeVar('MOD', bound=SparseConvolution) MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConv(SparseConvolution, nni._FusedModule):
    """Base class for QAT sparse convolutions: a ``SparseConvolution`` whose
    weight goes through a fake-quantize module before every forward pass.

    Class attributes:
        _FLOAT_MODULE: float module type accepted by :meth:`from_float`
            (overridden by concrete subclasses).
        _FLOAT_CONV_MODULE: float convolution type built by :meth:`to_float`.

    Subclasses that fuse a ReLU additionally define ``_FLOAT_RELU_MODULE``;
    :meth:`to_float` uses the presence of that attribute to decide whether to
    return a fused float module or a bare convolution.
    """

    _FLOAT_MODULE = MOD
    _FLOAT_CONV_MODULE = SparseConvolution

    def __init__(self,
                 ndim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 dilation: Union[int, List[int], Tuple[int, ...]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 subm: bool = False,
                 output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 transposed: bool = False,
                 inverse: bool = False,
                 indice_key: Optional[str] = None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 record_voxel_count: bool = False,
                 act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
                 act_alpha: float = 0,
                 act_beta: float = 0,
                 name=None,
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        """Build the underlying sparse convolution and attach the weight fake-quant.

        All arguments except ``qconfig`` are forwarded to
        ``SparseConvolution.__init__``. ``qconfig`` is mandatory and its
        ``weight`` factory is used to create ``self.weight_fake_quant``.

        Raises:
            AssertionError: if ``qconfig`` is missing or falsy.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        # Honour the caller's ``bias`` choice. It used to be hard-coded to
        # False here, which silently dropped the bias for directly
        # constructed modules.
        SparseConvolution.__init__(self, ndim, in_channels, out_channels,
                                   kernel_size, stride, padding, dilation, groups,
                                   bias=bias,
                                   subm=subm,
                                   output_padding=output_padding,
                                   transposed=transposed,
                                   inverse=inverse,
                                   indice_key=indice_key,
                                   algo=algo,
                                   fp32_accum=fp32_accum,
                                   record_voxel_count=record_voxel_count,
                                   act_type=act_type,
                                   act_alpha=act_alpha,
                                   act_beta=act_beta,
                                   name=name, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        """Run the sparse convolution with the fake-quantized weight."""
        # Pass the real training flag, as the sibling QAT modules in this file
        # do; a literal False was passed here before.
        # NOTE(review): presumably the flag selects the autograd-capable
        # path in _conv_forward -- confirm against SparseConvolution.
        return self._conv_forward(self.training, input,
                                  self.weight_fake_quant(self.weight), self.bias)

    @staticmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        This is a ``staticmethod`` that takes the target class explicitly;
        subclasses wrap it in a ``classmethod``.

        Args:
            `cls`: the QAT class to instantiate
            `mod`: a float module, either produced by torch.ao.quantization utilities
            or directly from user

        Returns:
            A new ``cls`` instance that shares ``weight`` and ``bias`` with
            the float convolution.
        """
        # NOTE(review): exact-type match (not isinstance) mirrors the torch
        # QAT modules; subclasses of _FLOAT_MODULE are rejected -- confirm
        # this is intended.
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        if issubclass(type(mod), nni._FusedModule):
            # Fused float modules keep the convolution as their first child.
            mod = mod[0]  # type: ignore[index]
        conv: SparseConvolution = mod
        qconfig = conv.qconfig
        qat_conv = cls(conv.ndim, conv.in_channels, conv.out_channels, conv.kernel_size,
                       conv.stride, conv.padding, conv.dilation,
                       conv.groups,
                       conv.bias is not None,
                       subm=conv.subm,
                       output_padding=conv.output_padding,
                       transposed=conv.transposed,
                       inverse=conv.inverse,
                       indice_key=conv.indice_key,
                       algo=conv.algo,
                       fp32_accum=conv.fp32_accum,
                       record_voxel_count=conv.record_voxel_count,
                       act_type=conv.act_type,
                       act_alpha=conv.act_alpha,
                       act_beta=conv.act_beta,
                       name=conv.name,
                       qconfig=qconfig)
        # Share (not copy) the float parameters with the QAT module.
        qat_conv.weight = conv.weight
        qat_conv.bias = conv.bias
        return qat_conv

    def to_float(self):
        """ This works for both single qat conv, and the qat conv - relu modules
        to convert the qat module to a floating point module

        Returns:
            A bare ``_FLOAT_CONV_MODULE`` when the class defines no
            ``_FLOAT_RELU_MODULE``; otherwise a ``_FLOAT_MODULE`` wrapping
            the convolution and a fresh ReLU, put in the same train/eval mode
            as ``self``.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.ndim,
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            subm=self.subm,
            output_padding=self.output_padding,
            transposed=self.transposed,
            inverse=self.inverse,
            indice_key=self.indice_key,
            algo=self.algo,
            fp32_accum=self.fp32_accum,
            record_voxel_count=self.record_voxel_count,
            act_type=self.act_type,
            act_alpha=self.act_alpha,
            act_beta=self.act_beta,
            name=self.name)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        # Every subclass inherits nni._FusedModule through this base class, so
        # an issubclass() test cannot tell "conv" from "conv + relu". Key the
        # decision on the ReLU class attribute instead, so plain convolutions
        # no longer trip the _FLOAT_RELU_MODULE assertion.
        relu_cls = getattr(cls, "_FLOAT_RELU_MODULE", None)
        if relu_cls is None:
            return conv

        # conv relu
        fused = cls._FLOAT_MODULE(conv, relu_cls())  # type: ignore[arg-type, attr-defined, operator]
        fused.train(self.training)
        return fused
class SparseConv(_SparseConv, SparseConvolution):
    r"""
    A sparse convolution module attached with a FakeQuantize module for its
    weight, used for quantization aware training.

    The float counterpart accepted by :meth:`from_float` and the float
    convolution type are both ``SparseConvolution``.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = SparseConvolution
    _FLOAT_CONV_MODULE = SparseConvolution

    @classmethod
    def from_float(cls, mod):
        """Create a QAT module of type ``cls`` from the float module ``mod``."""
        # The parent implementation is a staticmethod that expects the target
        # class as an explicit first argument.
        parent_from_float = super(SparseConv, cls).from_float
        return parent_from_float(cls, mod)
class SparseConvReLU(SparseConv, nni._FusedModule):
    r"""A QAT module fusing a sparse convolution with a ReLU, attached with a
    FakeQuantize module for the weight.

    The float counterpart is ``snni.SpconvReLUNd``; there is no batch-norm
    stage (``_FLOAT_BN_MODULE`` is ``None``).

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = snni.SpconvReLUNd
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def forward(self, input):
        """Convolve with the fake-quantized weight, then apply ReLU to the features."""
        fq_weight = self.weight_fake_quant(self.weight)
        conv_out = self._conv_forward(self.training, input, fq_weight, self.bias)
        activated = F.relu(conv_out.features)
        return conv_out.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Create a QAT module of type ``cls`` from the float module ``mod``."""
        return super().from_float(mod)
class _SparseConvBn(SparseConvolution, nni._FusedModule): class _SparseConvBn(SparseConvolution, nni._FusedModule):
_version = 2 _version = 2
...@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
stride: Union[int, List[int], Tuple[int, ...]] = 1, stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]] = 0, padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]] = 1, dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: Union[int, List[int], Tuple[int, ...]] = 1, groups: int = 1,
bias: bool = True, bias: bool = True,
subm: bool = False, subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]] = 0, output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
...@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias = torch.zeros_like(self.bias, dtype=input.features.dtype) zero_bias = torch.zeros_like(self.bias, dtype=input.features.dtype)
else: else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype) zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
conv_spt = self._conv_forward(input, scaled_weight, zero_bias) conv_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv = conv_spt.features conv = conv_spt.features
conv_orig = conv / scale_factor.reshape(bias_shape) conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None: if self.bias is not None:
...@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule): ...@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined] if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
# fuse bn into conv # fuse bn into conv
conv.weight, conv.bias = fuse_conv_bn_weights( conv.weight, conv.bias = fuse_spconv_bn_weights(
conv.weight, conv.weight,
conv.bias, conv.bias,
self.bn.running_mean, self.bn.running_mean,
...@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn): ...@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
@classmethod @classmethod
def from_float(cls, mod): def from_float(cls, mod):
return super(SparseConvBnReLU, cls).from_float(mod) return super(SparseConvBnReLU, cls).from_float(mod)
class SparseConvBnAddReLU(_SparseConvBn):
    r"""
    A QAT module fused from a sparse convolution, BatchNorm1d, a second
    ``add_input`` operand and ReLU, attached with FakeQuantize modules for
    weight.

    ``forward`` hands ``add_input`` to the base class ``_forward`` and applies
    ReLU to the features of the result.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    # NOTE(review): _FLOAT_MODULE / _FUSED_FLOAT_MODULE below name the plain
    # conv-bn-relu and conv-relu intrinsic modules, the same ones a
    # conv-bn-relu QAT class would use. This looks copied; an "add" variant
    # presumably should map to the add-aware intrinsic modules -- confirm
    # which names spconv.pytorch.quantization.intrinsic exports.
    _FLOAT_MODULE = snni.SpconvBnReLUNd  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = snni.SpconvReLUNd

    def forward(self, input, add_input):
        """Run the fused conv-bn path with ``add_input``, then ReLU the features."""
        # Call the base implementation explicitly, as the original does.
        fused_out = _SparseConvBn._forward(self, input, add_input)
        relu_features = F.relu(fused_out.features)
        return fused_out.replace_feature(relu_features)

    @classmethod
    def from_float(cls, mod):
        """Create a QAT module of type ``cls`` from the float module ``mod``."""
        return super().from_float(mod)
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv_relu import *
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment