Commit b1c57a31 authored by yan.yan's avatar yan.yan
Browse files

still working on int8

parent aa26c99e
......@@ -23,33 +23,95 @@ from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
import torch.ao.quantization
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.ao.quantization.quantize_fx as qfx
from spconv.pytorch.quantization.fake_q import get_default_spconv_qconfig_mapping
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
from torch.ao.quantization import get_default_qconfig_mapping
from spconv.pytorch.quantization.backend_cfg import SPCONV_STATIC_LOWER_FUSED_MODULE_MAP
from torch.ao.quantization.fx._lower_to_native_backend import STATIC_LOWER_FUSED_MODULE_MAP
@contextlib.contextmanager
def identity_ctx():
yield
class SubMConvBNReLU(spconv.SparseSequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(SubMConvBNReLU, self).__init__(
spconv.SubMConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm1d(out_planes, momentum=0.1),
# Replace with ReLU
nn.ReLU(inplace=False)
)
class SparseConvBNReLU(spconv.SparseSequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(SparseConvBNReLU, self).__init__(
spconv.SparseConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm1d(out_planes, momentum=0.1),
# Replace with ReLU
nn.ReLU(inplace=False)
)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = spconv.SparseSequential(
nn.BatchNorm1d(1),
spconv.SubMConv2d(1, 32, 3, 1),
nn.ReLU(),
spconv.SubMConv2d(32, 64, 3, 1),
nn.ReLU(),
spconv.SparseConv2d(64, 64, 2, 2),
SubMConvBNReLU(1, 32, 3),
SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2),
spconv.ToDense(),
)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x_sp: spconv.SparseConvTensor):
# def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor
# x = self.quant(x)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
# x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp)
x = torch.flatten(x, 1)
x = self.dropout1(x)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
# x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
class NetV2(nn.Module):
def __init__(self):
super(NetV2, self).__init__()
self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3),
SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2),
spconv.ToDense(),
)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x: torch.Tensor):
def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x = self.quant(features)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp)
x = torch.flatten(x, 1)
......@@ -58,10 +120,93 @@ class Net(nn.Module):
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
class NetPTQ(nn.Module):
"""pytorch currently don't support cuda int8 inference, so
we only use sparse ops here.
"""
def __init__(self):
super(NetPTQ, self).__init__()
self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3),
SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2), # 14x14
SparseConvBNReLU(64, 64, 2, 2), # 7x7
SparseConvBNReLU(64, 64, 3, 2, 1), # 4x4
spconv.SparseConv2d(64, 10, 4, 4),
spconv.ToDense(),
)
# self.fc1 = nn.Linear(64 * 1 * 1, 128)
# self.fc2 = nn.Linear(128, 10)
# self.dropout1 = nn.Dropout2d(0.25)
# self.dropout2 = nn.Dropout2d(0.5)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
# x: [N, 28, 28, 1], must be NHWC tensor
features = self.quant(features)
# x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
# create SparseConvTensor manually: see SparseConvTensor.from_dense
x_sp = self.net(x_sp)
# print(x_sp.shape)
x = x_sp
x = torch.flatten(x, 1)
# x_res = torch.zeros_like(x)
# x_res[x_sp.indices[:, 0].long()] = x
# x = x_res
# x = torch.flatten(x, 1)
# x = self.dropout1(x)
# x = self.fc1(x)
# x = F.relu(x)
# x = self.dropout2(x)
# x = self.fc2(x)
# print(x_sp.features.shape, x_sp.spatial_shape)
x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
class NetDense(nn.Module):
def __init__(self):
super(NetDense, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
x = self.dequant(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
scaler = torch.cuda.amp.grad_scaler.GradScaler()
......@@ -72,7 +217,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with amp_ctx:
if args.sparse:
data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
# output = model(data_sp)
output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
else:
output = model(data)
loss = F.nll_loss(output, target)
scale = 1.0
if args.fp16:
......@@ -114,7 +265,11 @@ def test(args, model, device, test_loader):
data, target = data.to(device), target.to(device)
with amp_ctx:
if args.sparse:
data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
# output = model(data_sp)
output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
else:
output = model(data)
test_loss += F.nll_loss(
output, target, reduction='sum').item() # sum up batch loss
......@@ -131,6 +286,19 @@ def test(args, model, device, test_loader):
100. * correct / len(test_loader.dataset)))
def calibrate(args, model: torch.nn.Module, data_loader, device):
model.eval()
with torch.no_grad():
for image, target in data_loader:
image = image.to(device)
if args.sparse:
data_sp = spconv.SparseConvTensor.from_dense(image.reshape(-1, 28, 28, 1))
output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
# output = model(data_sp)
else:
output = model(image)
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
......@@ -146,7 +314,7 @@ def main():
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs',
type=int,
default=14,
default=1,
metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr',
......@@ -168,6 +336,10 @@ def main():
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument('--sparse',
action='store_true',
default=True,
help='use sparse conv network instead of dense')
parser.add_argument(
'--log-interval',
type=int,
......@@ -190,8 +362,14 @@ def main():
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse:
model = NetPTQ().to(device)
else:
model = NetDense().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
......@@ -218,17 +396,46 @@ def main():
shuffle=True,
**kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()
# if args.save_model:
# torch.save(model.state_dict(), "mnist_cnn.pt")
model.eval()
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
if not args.sparse:
model = model.cpu()
# qconfig_mapping_default = get_default_qconfig_mapping("x86")
qconfig_mapping = get_default_spconv_qconfig_mapping(False)
prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config()
convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# prepared_model.print_readable()
print([type(m) for m in prepared_model.modules()])
print(prepared_model)
# calibrate: run model with some inputs
calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_to_reference_fx(prepared_model, convert_cfg, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
print([type(m) for m in converted_model.modules()])
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
# model.qconfig = get_default_spconv_trt_ptq_qconfig()
# prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# convert_custom_config_dict = spconvq.get_convert_custom_config()
# torch.ao.quantization.prepare(model, inplace=True)
# print('Post Training Quantization Prepare: Inserting Observers')
# print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# test(args, model, device, test_loader)
print(converted_model)
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
test(args, converted_model, qdevice, test_loader)
breakpoint()
if __name__ == '__main__':
......
......@@ -21,7 +21,7 @@ from ccimport.compat import InWindows
from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT
if project_is_installed(PACKAGE_NAME) and project_is_editable(
PACKAGE_NAME) and not DISABLE_JIT:
PACKAGE_NAME) and not DISABLE_JIT and False:
from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS, SHUFFLE_AMPERE_PARAMS
from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_AMPERE_PARAMS
......
......@@ -699,6 +699,22 @@ IMPLGEMM_AMPERE_PARAMS = [
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
......@@ -797,7 +813,21 @@ IMPLGEMM_TURING_PARAMS = [
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=0,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
......
......@@ -2,7 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> Tensor:
"""
Args:
name:
......@@ -11,9 +11,10 @@ class ExternalAllocator:
device:
stream:
is_temp_memory:
scale:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
def empty(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> Tensor:
"""
Args:
name:
......@@ -22,6 +23,7 @@ class ExternalAllocator:
device:
stream:
is_temp_memory:
scale:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False) -> Tensor:
......
......@@ -63,7 +63,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True, output_scale: float = 1.0, scale: Tensor = Tensor(), output_add: Tensor = Tensor(), output_add_scale: float = 1.0, output_dtype: int = -1) -> Tuple[int, Any]:
"""
Args:
allocator:
......@@ -91,6 +91,7 @@ class ConvGemmOps:
scale:
output_add:
output_add_scale:
output_dtype:
"""
...
@staticmethod
......
......@@ -56,7 +56,7 @@ class ExternalAllocator(pccm.Class):
code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......@@ -69,7 +69,7 @@ class ExternalAllocator(pccm.Class):
code.arg("device", "int")
code.arg("stream", "std::uintptr_t", "0")
code.arg("is_temp_memory", "bool", "false")
code.arg("scale", "float", "1.0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......
......@@ -2127,10 +2127,10 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int, false /*is_temp*/, output_scale);
}}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int);
{{num_activate_out, out_channel}}, tv::DType(output_dtype), features.device(), stream_int, false /*is_temp*/, output_scale);
}}
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
......
......@@ -311,7 +311,7 @@ class IndiceMaxPool(pccm.Class):
code.raw(f"""
auto nhot = out_inds.dim(0);
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
......@@ -350,7 +350,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
......@@ -478,7 +478,7 @@ class IndiceMaxPool(pccm.Class):
tv::check_shape(in, {{-1, out.dim(1)}});
auto cudastream = reinterpret_cast<cudaStream_t>(stream);
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t>(out.dtype(), [&](auto I){{
tv::dispatch<float, double, tv::half_t, tv::bfloat16_t, int8_t>(out.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
auto launchdims = LaunchUtils::get_blocks_threads_of_2d_tensor(nhot, out.dim(1));
int num_blocks_X = std::get<0>(launchdims);
......
......@@ -36,7 +36,7 @@ from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC, SPCONV_DEB
from spconv.utils import nullcontext
from torch.nn.init import calculate_gain
from cumm import tensorview as tv
from collections import namedtuple
from torch.nn import functional as F
FILTER_HWIO = False
......@@ -55,12 +55,7 @@ def _apply_act(x: torch.Tensor, act_type: tv.gemm.Activation, act_alpha: float,
else:
raise NotImplementedError
class SparseConvolution(SparseModule):
__constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
'transposed', 'output_padding'
]
class SparseConvolutionBase:
def __init__(self,
ndim: int,
in_channels: int,
......@@ -69,7 +64,7 @@ class SparseConvolution(SparseModule):
stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: Union[int, List[int], Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
......@@ -81,21 +76,17 @@ class SparseConvolution(SparseModule):
record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0,
name=None):
super(SparseConvolution, self).__init__(name=name)
act_beta: float = 0):
assert groups == 1, "don't support groups for now"
self.ndim = ndim
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = expand_nd(ndim, kernel_size)
self.stride = expand_nd(ndim, stride)
kv = int(np.prod(self.kernel_size))
kv_stride = int(np.prod(self.stride))
self.dilation = expand_nd(ndim, dilation)
self.padding = expand_nd(ndim, padding)
self.conv1x1 = kv == 1
# TODO we should deprecate support for ksize == 1 but stride != 1.
if not subm:
......@@ -110,11 +101,6 @@ class SparseConvolution(SparseModule):
self.groups = groups
self.subm = subm
self.indice_key = indice_key
if record_voxel_count and not self.subm and not self.inverse:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32))
self.record_voxel_count = record_voxel_count
if algo is None:
if kv <= 128 and not CPU_ONLY_BUILD:
......@@ -131,169 +117,37 @@ class SparseConvolution(SparseModule):
self.algo = algo
self.fp32_accum = fp32_accum
# self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
# RSCK
self.weight = Parameter(
torch.Tensor(*self.kernel_size, in_channels, out_channels))
weight_shape = [*self.kernel_size, in_channels, out_channels]
else:
# RSKC
self.weight = Parameter(
torch.Tensor(*self.kernel_size, out_channels, in_channels))
weight_shape = [*self.kernel_size, out_channels, in_channels]
else:
# KRSC
self.weight = Parameter(
torch.Tensor(out_channels, *self.kernel_size, in_channels))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
weight_shape = [out_channels, *self.kernel_size, in_channels]
self.weight_shape = weight_shape
self.act_type = act_type
self.act_alpha = act_alpha
self.act_beta = act_beta
self.enable_int8_test_mode: bool = False
self._int8_weight = torch.Tensor()
# calculated by max(abs(weight)) for each channel
self._int8_weight_scale = torch.Tensor()
# calculated by scale self.bias with _int8_input_scale
self._int8_bias = torch.Tensor()
# int8 inference must set _int8_input_scale
self._int8_input_scale: Optional[float] = None
# if _int8_output_scale unset, will execute s8 @ s8 => f16/f32 (weight dtype), i.e. dequantization
self._int8_output_scale: Optional[float] = None
self.scale = 1.0
self.zero_point = 0
if self.conv1x1:
assert act_type == tv.gemm.Activation.None_, "conv1x1 don't support fused act"
self.reset_parameters()
if hasattr(self, "_register_load_state_dict_pre_hook"):
self._register_load_state_dict_pre_hook(
self._load_weight_different_layout)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
def set_int8_test(self, enable: bool, input_scale: float, output_scale: Optional[float] = None, weight_scale: Optional[torch.Tensor] = None):
self._int8_input_scale = input_scale
self._int8_output_scale = output_scale
if weight_scale is not None:
self._int8_weight_scale = weight_scale
self.enable_int8_test_mode = enable
def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs):
name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT:
return
key = prefix + "weight"
assert key in state_dict
ndim = self.ndim
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim),
ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim),
ndim).contiguous()
if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native:
# in spconv 2.2, we only support KRSC layout.
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(
ndim, *range(ndim), ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(
ndim + 1, *range(ndim), ndim).contiguous()
else:
if self.algo == ConvAlgo.Native:
# to RSCK
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(
*range(ndim), ndim + 1, ndim).contiguous()
elif SAVED_WEIGHT_LAYOUT == "KRSC":
state_dict[key] = state_dict[key].permute(
*range(1, ndim + 1), 0, ndim + 1).contiguous()
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0, ) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1, ) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0, ) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.algo is not None:
s += f', algo={self.algo}'
return s.format(**self.__dict__)
def _calculate_fan_in_and_fan_out(self):
receptive_field_size = 1
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for s in self.kernel_size:
receptive_field_size *= s
fan_in = self.in_channels * receptive_field_size
fan_out = self.out_channels * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(self, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError(
"Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = self._calculate_fan_in_and_fan_out()
return fan_in if mode == 'fan_in' else fan_out
def _custom_kaiming_uniform_(self,
tensor,
a=0,
mode='fan_in',
nonlinearity='leaky_relu'):
r"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan = self._calculate_correct_fan(mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(
3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def reset_parameters(self):
if SPCONV_DEBUG_WEIGHT:
self._custom_kaiming_uniform_(self.weight, a=math.sqrt(0.005))
else:
self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = self._calculate_fan_in_and_fan_out()
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(input, self.weight, self.bias, add_input)
def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None):
assert isinstance(input, SparseConvTensor)
def _conv_forward(self, training: bool, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None, name: Optional[str] = None,
sparse_unique_name: str = "",
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0):
# assert isinstance(input, SparseConvTensor)
is_int8 = input.is_quantized and weight.is_quantized
if is_int8:
assert output_scale is not None and channel_scale is not None, "int8 must be called in static quantized module"
assert bias is not None, "currently you must specify a bias"
assert input.features.shape[
1] == self.in_channels, "channel size mismatch"
features = input.features
......@@ -301,35 +155,38 @@ class SparseConvolution(SparseModule):
indices = input.indices
spatial_shape = input.spatial_shape
batch_size = input.batch_size
bias_for_training = bias if self.training else None
bias_for_infer = bias if not self.training else None
output_scale = None
bias_for_training = bias if training else None
bias_for_infer = bias if not training else None
output_add_scale = 1.0
if self.enable_int8_test_mode:
assert not self.training, "must in eval mode"
assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
assert bias_for_infer is not None, "conv-bn-relu must be fused"
assert self._int8_input_scale is not None
if features.dtype != torch.int8:
# quantize
features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
output_scale = self._int8_output_scale
int8_out_scale = output_scale
if int8_out_scale is None:
int8_out_scale = 1
if is_int8:
if add_input is not None:
assert add_input.int8_scale is not None, "only support int8 add"
output_add_scale = add_input.int8_scale
if self._int8_weight.numel() == 0:
with torch.no_grad():
assert ALL_WEIGHT_IS_KRSC
weight_scales = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
num_1s = [1] * (self.ndim + 1)
self._int8_weight = (weight / weight_scales.view(self.out_channels, *num_1s) * 127).to(torch.int8)
if self._int8_weight_scale.numel() == 0:
self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scales)
self._int8_bias = bias_for_infer * int8_out_scale
if self.training:
output_add_scale = add_input.q_scale()
# if self.enable_int8_test_mode:
# assert not self.training, "must in eval mode"
# assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# assert bias_for_infer is not None, "conv-bn-relu must be fused"
# assert self._int8_input_scale is not None
# if features.dtype != torch.int8:
# # quantize
# features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# output_scale = self._int8_output_scale
# int8_out_scale = output_scale
# if int8_out_scale is None:
# int8_out_scale = 1
# if add_input is not None:
# assert add_input.int8_scale is not None, "only support int8 add"
# output_add_scale = add_input.int8_scale
# if self._int8_weight.numel() == 0:
# with torch.no_grad():
# assert ALL_WEIGHT_IS_KRSC
# weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# num_1s = [1] * (self.ndim + 1)
# self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# if self._int8_weight_scale.numel() == 0:
# self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# self._int8_bias = bias_for_infer * int8_out_scale
if training:
msg = "act don't support backward, only used in inference"
assert self.act_type == tv.gemm.Activation.None_, msg
......@@ -349,12 +206,12 @@ class SparseConvolution(SparseModule):
# t = time.time()
out_tensor = input.shadow_copy()
if input.benchmark:
if self.name is None:
if name is None:
raise ValueError(
"you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
)
if self.name not in input.benchmark_record:
input.benchmark_record[self.name] = {
if name not in input.benchmark_record:
input.benchmark_record[name] = {
"type": "SparseConvolution",
"indice_gen_time": [],
"time": [],
......@@ -372,7 +229,7 @@ class SparseConvolution(SparseModule):
"out_channels": self.out_channels,
}
}
if self.conv1x1 and not self.enable_int8_test_mode:
if self.conv1x1 and not is_int8:
# in int8 test mode, we don't implement conv1x1 via mm.
if FILTER_HWIO:
features = torch.mm(
......@@ -401,8 +258,8 @@ class SparseConvolution(SparseModule):
assert algo == datas.algo, msg
# algo = datas.algo
profile_ctx = nullcontext()
if input._timer is not None and self._sparse_unique_name:
profile_ctx = input._timer.namespace(self._sparse_unique_name)
if input._timer is not None and sparse_unique_name:
profile_ctx = input._timer.namespace(sparse_unique_name)
with profile_ctx:
if algo == ConvAlgo.Native:
datas = input.find_indice_pair(self.indice_key)
......@@ -449,7 +306,7 @@ class SparseConvolution(SparseModule):
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
name]["indice_gen_time"].append(interval)
indice_data = IndiceData(outids,
indices,
......@@ -567,7 +424,7 @@ class SparseConvolution(SparseModule):
out_padding=self.output_padding,
subm=self.subm,
transpose=self.transposed,
is_train=(not self.subm) or self.training,
is_train=(not self.subm) or training,
alloc=input.thrust_allocator,
timer=input._timer)
except Exception as e:
......@@ -583,7 +440,7 @@ class SparseConvolution(SparseModule):
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[
self.name]["indice_gen_time"].append(interval)
name]["indice_gen_time"].append(interval)
outids = res[0]
num_inds_per_loc = res[1]
pair_fwd = res[2]
......@@ -621,16 +478,16 @@ class SparseConvolution(SparseModule):
num_activate_out = outids.shape[0]
weight_cur = weight
bias_cur = bias_for_infer
if self.enable_int8_test_mode:
assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
weight_cur = self._int8_weight
bias_cur = self._int8_bias
if self.training:
# if self.enable_int8_test_mode:
# assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
# weight_cur = self._int8_weight
# bias_cur = self._int8_bias
if training:
out_features = Fsp.implicit_gemm(
features, weight_cur, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
......@@ -638,20 +495,20 @@ class SparseConvolution(SparseModule):
self.act_type)
else:
output_dtype = None
if self._int8_output_scale is None:
if output_scale is None:
output_dtype = weight.dtype
out_features, _, _ = ops.implicit_gemm(
features, weight_cur, pair_fwd, pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks, self.training, self.subm,
num_activate_out, masks, training, self.subm,
input._timer, self.fp32_accum,
bias_cur,
self.act_alpha,
self.act_beta,
self.act_type,
# TODO do we really need output scale to scale bias in kernel?
1.0, # output_scale
self._int8_weight_scale, # scale
1.0 if output_scale is None else output_scale, # output_scale
channel_scale, # scale
output_add=add_input.features if add_input is not None else None,
output_add_scale=output_add_scale,
output_dtype=output_dtype)
......@@ -661,10 +518,10 @@ class SparseConvolution(SparseModule):
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
out_tensor.benchmark_record[self.name]["time"].append(interval)
out_tensor.benchmark_record[self.name]["num_points"].append(
out_tensor.benchmark_record[name]["time"].append(interval)
out_tensor.benchmark_record[name]["num_points"].append(
features.shape[0])
out_tensor.benchmark_record[self.name]["num_out_points"].append(
out_tensor.benchmark_record[name]["num_out_points"].append(
out_features.shape[0])
if not self.subm and not self.inverse and self.record_voxel_count:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
......@@ -675,9 +532,9 @@ class SparseConvolution(SparseModule):
out_tensor.indices = outids
out_tensor.indice_dict = indice_dict
out_tensor.spatial_shape = out_spatial_shape
if add_input is not None and not self.enable_int8_test_mode:
if add_input is not None and not is_int8:
# in int8, we apply add + act in kernel.
out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
out_tensor.int8_scale = output_scale
return out_tensor
......@@ -725,6 +582,629 @@ class SparseConvolution(SparseModule):
"please check Inverse Convolution in ."
)
class SparseConvolution(SparseConvolutionBase, SparseModule):
__constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
'transposed', 'output_padding'
]
def __init__(self,
ndim: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
transposed: bool = False,
inverse: bool = False,
indice_key: Optional[str] = None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0,
name=None,
device=None,
dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
SparseConvolutionBase.__init__(self, ndim, in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
bias=False,
subm=subm,
output_padding=output_padding,
transposed=transposed,
inverse=inverse,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta)
SparseModule.__init__(self, name=name)
if record_voxel_count and not self.subm and not self.inverse:
# we record maximum voxel num in both inference and training if
# record_voxel_count flag setting.
self.register_buffer(_MAX_NUM_VOXELS_DURING_TRAINING,
torch.zeros(1, dtype=torch.int32, device=device))
self.weight = Parameter(torch.zeros(*self.weight_shape, **factory_kwargs))
if bias:
self.bias = Parameter(torch.zeros(out_channels, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.reset_parameters()
if hasattr(self, "_register_load_state_dict_pre_hook"):
self._register_load_state_dict_pre_hook(
self._load_weight_different_layout)
def get_max_num_voxels(self) -> Optional[torch.Tensor]:
if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
return getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING)
return None
# def set_int8_test(self, enable: bool, input_scale: float, output_scale: Optional[float] = None, weight_scale: Optional[torch.Tensor] = None):
# self._int8_input_scale = input_scale
# self._int8_output_scale = output_scale
# if weight_scale is not None:
# self._int8_weight_scale = weight_scale
# self.enable_int8_test_mode = enable
def _load_weight_different_layout(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs):
name = prefix + _MAX_NUM_VOXELS_DURING_TRAINING
if self.record_voxel_count and not self.subm and not self.inverse and name not in state_dict:
state_dict[name] = torch.zeros(
1, dtype=torch.int32)
if not SAVED_WEIGHT_LAYOUT:
return
key = prefix + "weight"
assert key in state_dict
ndim = self.ndim
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(ndim, *range(ndim),
ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(ndim + 1, *range(ndim),
ndim).contiguous()
if ALL_WEIGHT_IS_KRSC or self.algo != ConvAlgo.Native:
# in spconv 2.2, we only support KRSC layout.
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(
ndim, *range(ndim), ndim + 1).contiguous()
elif SAVED_WEIGHT_LAYOUT == "RSCK":
state_dict[key] = state_dict[key].permute(
ndim + 1, *range(ndim), ndim).contiguous()
else:
if self.algo == ConvAlgo.Native:
# to RSCK
if SAVED_WEIGHT_LAYOUT == "RSKC":
state_dict[key] = state_dict[key].permute(
*range(ndim), ndim + 1, ndim).contiguous()
elif SAVED_WEIGHT_LAYOUT == "KRSC":
state_dict[key] = state_dict[key].permute(
*range(1, ndim + 1), 0, ndim + 1).contiguous()
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0, ) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1, ) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0, ) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.algo is not None:
s += f', algo={self.algo}'
return s.format(**self.__dict__)
def _calculate_fan_in_and_fan_out(self):
receptive_field_size = 1
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for s in self.kernel_size:
receptive_field_size *= s
fan_in = self.in_channels * receptive_field_size
fan_out = self.out_channels * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(self, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError(
"Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = self._calculate_fan_in_and_fan_out()
return fan_in if mode == 'fan_in' else fan_out
def _custom_kaiming_uniform_(self,
tensor,
a=0,
mode='fan_in',
nonlinearity='leaky_relu'):
r"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan = self._calculate_correct_fan(mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(
3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def reset_parameters(self):
if SPCONV_DEBUG_WEIGHT:
self._custom_kaiming_uniform_(self.weight, a=math.sqrt(0.005))
else:
self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = self._calculate_fan_in_and_fan_out()
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def is_inverseable(self):
return self.indice_key is not None and not self.subm
def forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
return self._conv_forward(self.training, input, self.weight, self.bias, add_input,
name=self.name, sparse_unique_name=self._sparse_unique_name, act_type=self.act_type,
act_alpha=self.act_alpha, act_beta=self.act_beta)
# def _conv_forward(self, input: SparseConvTensor, weight: torch.Tensor, bias: Optional[torch.Tensor], add_input: Optional[SparseConvTensor] = None,
# channel_scale: Optional[torch.Tensor] = None, output_scale: Optional[float] = None):
# assert isinstance(input, SparseConvTensor)
# is_int8 = input.is_quantized and weight.is_quantized
# if is_int8:
# assert output_scale is not None and channel_scale is not None, "int8 must be called in static quantized module"
# assert bias is not None, "currently you must specify a bias"
# assert input.features.shape[
# 1] == self.in_channels, "channel size mismatch"
# features = input.features
# device = features.device
# indices = input.indices
# spatial_shape = input.spatial_shape
# batch_size = input.batch_size
# bias_for_training = bias if self.training else None
# bias_for_infer = bias if not self.training else None
# output_add_scale = 1.0
# if is_int8:
# if add_input is not None:
# output_add_scale = add_input.q_scale()
# # if self.enable_int8_test_mode:
# # assert not self.training, "must in eval mode"
# # assert self.algo == ConvAlgo.MaskImplicitGemm, "int8 inference only support MaskImplicitGemm"
# # assert bias_for_infer is not None, "conv-bn-relu must be fused"
# # assert self._int8_input_scale is not None
# # if features.dtype != torch.int8:
# # # quantize
# # features = torch.clamp(torch.round(features / self._int8_input_scale), -128, 127).to(torch.int8)
# # output_scale = self._int8_output_scale
# # int8_out_scale = output_scale
# # if int8_out_scale is None:
# # int8_out_scale = 1
# # if add_input is not None:
# # assert add_input.int8_scale is not None, "only support int8 add"
# # output_add_scale = add_input.int8_scale
# # if self._int8_weight.numel() == 0:
# # with torch.no_grad():
# # assert ALL_WEIGHT_IS_KRSC
# # weight_scale = torch.abs(weight).view(self.out_channels, -1).max(1)[0]
# # num_1s = [1] * (self.ndim + 1)
# # self._int8_weight = (weight / weight_scale.view(self.out_channels, *num_1s) * 127).to(torch.int8)
# # if self._int8_weight_scale.numel() == 0:
# # self._int8_weight_scale = int8_out_scale / (self._int8_input_scale * weight_scale)
# # self._int8_bias = bias_for_infer * int8_out_scale
# if self.training:
# msg = "act don't support backward, only used in inference"
# assert self.act_type == tv.gemm.Activation.None_, msg
# if not self.subm:
# if self.transposed:
# out_spatial_shape = ops.get_deconv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding)
# else:
# out_spatial_shape = ops.get_conv_output_size(
# spatial_shape, self.kernel_size, self.stride, self.padding,
# self.dilation)
# else:
# out_spatial_shape = spatial_shape
# # print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# # input.update_grid(out_spatial_shape)
# # t = time.time()
# out_tensor = input.shadow_copy()
# if input.benchmark:
# if self.name is None:
# raise ValueError(
# "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
# )
# if self.name not in input.benchmark_record:
# input.benchmark_record[self.name] = {
# "type": "SparseConvolution",
# "indice_gen_time": [],
# "time": [],
# "num_points": [],
# "num_out_points": [],
# "params": {
# "kernel_size": self.kernel_size,
# "stride": self.stride,
# "padding": self.padding,
# "dilation": self.dilation,
# "output_padding": self.output_padding,
# "subm": self.subm,
# "transposed": self.transposed,
# "input_channels": self.in_channels,
# "out_channels": self.out_channels,
# }
# }
# if self.conv1x1 and not is_int8:
# # in int8 test mode, we don't implement conv1x1 via mm.
# if FILTER_HWIO:
# features = torch.mm(
# input.features,
# weight.view(self.out_channels, self.in_channels).T)
# else:
# features = torch.mm(
# input.features,
# weight.view(self.in_channels, self.out_channels))
# if bias is not None:
# features += bias
# out_tensor = out_tensor.replace_feature(features)
# # padding may change spatial shape of conv 1x1.
# out_tensor.spatial_shape = out_spatial_shape
# return out_tensor
# indice_dict = input.indice_dict.copy()
# # only support contiguous tensor for now
# if not features.is_contiguous():
# features = features.contiguous()
# algo = self.algo
# if self.indice_key is not None:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
# assert algo == datas.algo, msg
# # algo = datas.algo
# profile_ctx = nullcontext()
# if input._timer is not None and self._sparse_unique_name:
# profile_ctx = input._timer.namespace(self._sparse_unique_name)
# with profile_ctx:
# if algo == ConvAlgo.Native:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, IndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# out_spatial_shape = datas.spatial_shape
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# indice_pairs = datas.indice_pairs
# indice_pair_num = datas.indice_pair_num
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# try:
# outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
# indices, batch_size, spatial_shape, algo,
# self.kernel_size, self.stride, self.padding,
# self.dilation, self.output_padding, self.subm,
# self.transposed)
# except Exception as e:
# msg = "[Exception|native_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# indice_data = IndiceData(outids,
# indices,
# indice_pairs,
# indice_pair_num,
# spatial_shape,
# out_spatial_shape,
# is_subm=self.subm,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# if self.indice_key is not None:
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# indice_pairs_calc = indice_pairs
# if indice_pairs.device != features.device:
# indice_pairs_calc = indice_pairs.to(features.device)
# if self.subm:
# out_features = Fsp.indice_subm_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# if self.inverse:
# out_features = Fsp.indice_inverse_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# out_features = Fsp.indice_conv(
# features,
# weight,
# indice_pairs_calc,
# indice_pair_num,
# outids.shape[0],
# algo,
# input._timer,
# bias_for_infer,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# datas = input.find_indice_pair(self.indice_key)
# if datas is not None:
# assert isinstance(datas, ImplicitGemmIndiceData)
# if self.inverse:
# assert datas is not None and self.indice_key is not None
# assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
# outids = datas.indices
# pair_fwd = datas.pair_bwd
# pair_bwd = datas.pair_fwd
# pair_mask_fwd_splits = datas.pair_mask_bwd_splits
# pair_mask_bwd_splits = datas.pair_mask_fwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
# masks = datas.masks
# out_spatial_shape = datas.spatial_shape
# # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
# self._check_inverse_reuse_valid(input, spatial_shape,
# datas)
# else:
# if self.indice_key is not None and datas is not None:
# outids = datas.out_indices
# pair_fwd = datas.pair_fwd
# pair_bwd = datas.pair_bwd
# pair_mask_fwd_splits = datas.pair_mask_fwd_splits
# pair_mask_bwd_splits = datas.pair_mask_bwd_splits
# mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
# mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
# masks = datas.masks
# assert self.subm, "only support reuse subm indices"
# self._check_subm_reuse_valid(input, spatial_shape,
# datas)
# else:
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# with input._timer.namespace("gen_pairs"):
# # we need to gen bwd indices for regular conv
# # because it may be inversed.
# try:
# res = ops.get_indice_pairs_implicit_gemm(
# indices,
# batch_size,
# spatial_shape,
# algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation,
# out_padding=self.output_padding,
# subm=self.subm,
# transpose=self.transposed,
# is_train=(not self.subm) or self.training,
# alloc=input.thrust_allocator,
# timer=input._timer)
# except Exception as e:
# msg = "[Exception|implicit_gemm_pair]"
# msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
# msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
# msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
# msg += f"transpose={self.transposed}"
# print(msg, file=sys.stderr)
# spconv_save_debug_data(indices)
# raise e
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[
# self.name]["indice_gen_time"].append(interval)
# outids = res[0]
# num_inds_per_loc = res[1]
# pair_fwd = res[2]
# pair_bwd = res[3]
# pair_mask_fwd_splits = res[4]
# pair_mask_bwd_splits = res[5]
# mask_argsort_fwd_splits = res[6]
# mask_argsort_bwd_splits = res[7]
# masks = res[8]
# if self.indice_key is not None:
# indice_data = ImplicitGemmIndiceData(
# outids,
# indices,
# pair_fwd,
# pair_bwd,
# pair_mask_fwd_splits=pair_mask_fwd_splits,
# pair_mask_bwd_splits=pair_mask_bwd_splits,
# mask_argsort_fwd_splits=mask_argsort_fwd_splits,
# mask_argsort_bwd_splits=mask_argsort_bwd_splits,
# masks=masks,
# is_subm=self.subm,
# spatial_shape=spatial_shape,
# out_spatial_shape=out_spatial_shape,
# algo=algo,
# ksize=self.kernel_size,
# stride=self.stride,
# padding=self.padding,
# dilation=self.dilation)
# msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
# assert self.indice_key not in indice_dict, msg
# indice_dict[self.indice_key] = indice_data
# if input.benchmark:
# torch.cuda.synchronize()
# t = time.time()
# num_activate_out = outids.shape[0]
# weight_cur = weight
# bias_cur = bias_for_infer
# # if self.enable_int8_test_mode:
# # assert features.dtype == torch.int8, "in int8 test mode, feature must be int8"
# # weight_cur = self._int8_weight
# # bias_cur = self._int8_bias
# if self.training:
# out_features = Fsp.implicit_gemm(
# features, weight_cur, pair_fwd, pair_bwd,
# pair_mask_fwd_splits, pair_mask_bwd_splits,
# mask_argsort_fwd_splits, mask_argsort_bwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type)
# else:
# output_dtype = None
# if output_scale is None:
# output_dtype = weight.dtype
# out_features, _, _ = ops.implicit_gemm(
# features, weight_cur, pair_fwd, pair_mask_fwd_splits,
# mask_argsort_fwd_splits,
# num_activate_out, masks, self.training, self.subm,
# input._timer, self.fp32_accum,
# bias_cur,
# self.act_alpha,
# self.act_beta,
# self.act_type,
# # TODO do we really need output scale to scale bias in kernel?
# 1.0 if output_scale is None else output_scale, # output_scale
# channel_scale, # scale
# output_add=add_input.features if add_input is not None else None,
# output_add_scale=output_add_scale,
# output_dtype=output_dtype)
# if bias_for_training is not None:
# out_features += bias_for_training
# if input.benchmark:
# torch.cuda.synchronize()
# interval = time.time() - t
# out_tensor.benchmark_record[self.name]["time"].append(interval)
# out_tensor.benchmark_record[self.name]["num_points"].append(
# features.shape[0])
# out_tensor.benchmark_record[self.name]["num_out_points"].append(
# out_features.shape[0])
# if not self.subm and not self.inverse and self.record_voxel_count:
# if hasattr(self, _MAX_NUM_VOXELS_DURING_TRAINING):
# ops.maximum_value_int_(
# getattr(self, _MAX_NUM_VOXELS_DURING_TRAINING),
# outids.shape[0])
# out_tensor = out_tensor.replace_feature(out_features)
# out_tensor.indices = outids
# out_tensor.indice_dict = indice_dict
# out_tensor.spatial_shape = out_spatial_shape
# if add_input is not None and not is_int8:
# # in int8, we apply add + act in kernel.
# out_tensor = out_tensor.replace_feature(_apply_act(out_tensor.features + add_input.features, self.act_type, self.act_alpha, self.act_beta))
# return out_tensor
# def _check_subm_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# assert datas.is_subm, "only support reuse subm indices"
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"subm with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}")
# if self.dilation != datas.dilation:
# raise ValueError(
# f"subm with same indice_key must have same dilation"
# f", expect {datas.dilation}, this layer {self.dilation}")
# if inp.spatial_shape != datas.spatial_shape:
# raise ValueError(
# f"subm with same indice_key must have same spatial structure"
# f", expect {datas.spatial_shape}, input {spatial_shape}")
# if inp.indices.shape[0] != datas.indices.shape[0]:
# raise ValueError(
# f"subm with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
# )
# def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
# spatial_shape: List[int],
# datas: Union[ImplicitGemmIndiceData,
# IndiceData]):
# if self.kernel_size != datas.ksize:
# raise ValueError(
# f"Inverse with same indice_key must have same kernel"
# f" size, expect {datas.ksize}, this layer {self.kernel_size}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.spatial_shape != datas.out_spatial_shape:
# raise ValueError(
# f"Inverse with same indice_key must have same spatial structure (spatial shape)"
# f", expect {datas.spatial_shape}, input {spatial_shape}, "
# "please check Inverse Convolution in docs/USAGE.md.")
# if inp.indices.shape[0] != datas.out_indices.shape[0]:
# raise ValueError(
# f"Inverse with same indice_key must have same num of indices"
# f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
# "please check Inverse Convolution in ."
# )
class SparseConv1d(SparseConvolution):
def __init__(self,
in_channels,
......@@ -1191,3 +1671,24 @@ class SubMConv4d(SparseConvolution):
algo=algo,
fp32_accum=fp32_accum,
name=name)
DEFAULT_SPARSE_CONV_TYPES = {
SubMConv1d,
SubMConv2d,
SubMConv3d,
SubMConv4d,
SparseConv1d,
SparseConv2d,
SparseConv3d,
SparseConv4d,
SparseInverseConv1d,
SparseInverseConv2d,
SparseInverseConv3d,
SparseInverseConv4d,
SparseConvTranspose1d,
SparseConvTranspose2d,
SparseConvTranspose3d,
SparseConvTranspose4d,
}
......@@ -128,7 +128,7 @@ def scatter_nd(indices, updates, shape):
return ret
# ProxyableClassMeta is used for TensorRT conversion in future.
# ProxyableClassMeta is used for torch.fx
class SparseConvTensor(metaclass=SpConvTensorMeta):
def __init__(self,
features: torch.Tensor,
......@@ -181,8 +181,15 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer)
self.force_algo = force_algo
# for simple int8 torch inference
self.int8_scale: Optional[float] = None
@property
def is_quantized(self):
return self.features.dtype == torch.qint8
def q_scale(self):
if self.is_quantized:
return self.features.q_scale()
raise ValueError("sparse tensor must be quantized")
def replace_feature(self, feature: torch.Tensor):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
......@@ -220,7 +227,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
x must be NHWC tensor, channel last
"""
x_sp = x.to_sparse(x.ndim - 1)
spatial_shape = list(x_sp.shape[1:-1])
spatial_shape = x_sp.shape[1:-1]
batch_size = x_sp.shape[0]
indices_th = x_sp.indices().permute(1, 0).contiguous().int()
features_th = x_sp.values()
......
......@@ -34,6 +34,7 @@ _TORCH_DTYPE_TO_TV = {
torch.int8: tv.int8,
torch.int16: tv.int16,
torch.uint8: tv.uint8,
torch.qint8: tv.int8,
}
_TORCH_UINT_WORKAROUNDS = {
......@@ -42,6 +43,8 @@ _TORCH_UINT_WORKAROUNDS = {
tv.uint64: tv.int64
}
_TH_QTYPES = {torch.qint8}
_TV_DTYPE_TO_TORCH = {v: k for k, v in _TORCH_DTYPE_TO_TV.items()}
_TV_DTYPE_TO_TORCH.update({
tv.uint32: torch.int32,
......@@ -50,6 +53,9 @@ _TV_DTYPE_TO_TORCH.update({
})
_TV_DTYPE_TO_TORCHQ = _TV_DTYPE_TO_TORCH.copy()
_TV_DTYPE_TO_TORCHQ[tv.int8] = torch.qint8
_ALL_INTS = {
tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
tv.uint16
......@@ -105,23 +111,31 @@ def get_arch():
class TorchAllocator(ExternalAllocator):
def __init__(self, gpudevice: torch.device) -> None:
def __init__(self, gpudevice: torch.device, is_quantized: bool = False) -> None:
super().__init__()
self.gpudevice = gpudevice
self.cpudevice = torch.device("cpu")
self.allocated: Dict[Union[str, int], torch.Tensor] = {}
self.is_quantized = is_quantized
self._tv_dtype_to_torch = _TV_DTYPE_TO_TORCH
if is_quantized:
self._tv_dtype_to_torch = _TV_DTYPE_TO_TORCHQ
def zeros(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> tv.Tensor:
# TODO free memory by name if its already free by pointer.
# provide a name if you want to access it after c++ function exit.
dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
ten = torch.zeros(shape, dtype=th_dtype, device=dev)
if self.is_quantized:
ten = torch._empty_affine_quantized(shape, scale=scale, zero_point=0, dtype=th_dtype, device=dev)
else:
ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_()
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory:
......@@ -129,13 +143,16 @@ class TorchAllocator(ExternalAllocator):
return ten_tv
def empty(self, name: str, shape: List[int], dtype: int,
device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> tv.Tensor:
dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
if self.is_quantized:
ten = torch._empty_affine_quantized(shape, scale=scale, zero_point=0, dtype=th_dtype, device=dev)
else:
ten = torch.empty(shape, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
......@@ -148,11 +165,13 @@ class TorchAllocator(ExternalAllocator):
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
if self.is_quantized:
assert th_dtype not in _TH_QTYPES
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
......@@ -165,11 +184,13 @@ class TorchAllocator(ExternalAllocator):
if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
raise NotImplementedError("you can't use full for unsigned dtypes")
dtype_bkp = dtype
th_dtype = _TV_DTYPE_TO_TORCH[dtype]
th_dtype = self._tv_dtype_to_torch[dtype]
if device == -1:
dev = self.cpudevice
else:
dev = self.gpudevice
if self.is_quantized:
assert th_dtype not in _TH_QTYPES
ten = torch.full(shape, value, dtype=th_dtype, device=dev)
ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
self.allocated[ten_tv.byte_pointer()] = ten
......
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backend_cfg import (get_spconv_backend_config,
get_spconv_prepare_custom_config,
get_spconv_convert_custom_config)
from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig)
from .qmapping import (get_spconv_fmod_to_qat_mapping,
get_spconv_qat_to_static_mapping)
from collections import namedtuple
from typing import List, Dict, Union, Type, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.quantization.backend_config import (BackendConfig,
BackendPatternConfig,
DTypeConfig, ObservationType,
get_tensorrt_backend_config)
from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig,
FuseCustomConfig,
PrepareCustomConfig)
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
import spconv.pytorch.conv as sconvmod
import spconv.pytorch.quantization.intrinsic as snni
import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic.quantized as snniq
import spconv.pytorch.quantization.quantized as snnq
import spconv.pytorch.quantization.quantized.reference as snnqr
from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn,
fuse_conv_bn_relu)
from spconv.pytorch import ToDense
_SpConvMetadataDef = namedtuple(
"_ConvMetadata",
["root", "bn", "reference",
"fused_conv_relu", "fused_conv_bn", "fused_conv_bn_relu",
"qat", "relu_qat", "bn_qat", "bn_relu_qat"])
_SpConvMetadatas: List[_SpConvMetadataDef] = []
for t in sconvmod.DEFAULT_SPARSE_CONV_TYPES:
_SpConvMetadatas.append(_SpConvMetadataDef(t, nn.BatchNorm1d,
snnqr.SpConv,
snni.SpconvReLUNd, snni.SpconvBnNd, snni.SpconvBnReLUNd,
snniqat.SparseConv, snniqat.SparseConvReLU, snniqat.SparseConvBn, snniqat.SparseConvBnReLU))
_SpConvMetadatas.append(_SpConvMetadataDef(
sconvmod.SparseConvolution, nn.BatchNorm1d,
snnqr.SpConv,
snni.SpconvReLUNd, snni.SpconvBnNd, snni.SpconvBnReLUNd,
snniqat.SparseConv, snniqat.SparseConvReLU, snniqat.SparseConvBn, snniqat.SparseConvBnReLU))
def _sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def fuser_method(is_qat, m1, m2):
return sequential(m1, m2)
return fuser_method
# new cfg remove reverse pattern.
def _get_spconv_configs(dtype_configs):
"""
Return all configs related to conv modules and ops.
"""
conv_configs = []
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
if PYTORCH_VERSION <= [1, 13, 1]:
from torch.ao.quantization.fuser_method_mappings import (
reverse2, reverse3, reverse_sequential_wrapper2)
for convs in _SpConvMetadatas:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs.append(
BackendPatternConfig(convs.root)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference)
.set_qat_module(convs.qat))
# conv qat module
conv_configs.append(
BackendPatternConfig(convs.qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs.append(
BackendPatternConfig((torch.nn.ReLU, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv relu fusion, conv module + functional relu
conv_configs.append(
BackendPatternConfig((F.relu, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference)
.set_qat_module(convs.relu_qat))
# conv relu, qat fused module
conv_configs.append(
BackendPatternConfig(convs.relu_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# 2.3 functional conv + relu configs
# fused conv relu
conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.relu_qat))
conv_configs.append(
BackendPatternConfig(convs.relu_qat)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs.append(
BackendPatternConfig((convs.bn, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse2(fuse_conv_bn))
.set_fused_module(convs.fused_conv_bn))
# conv + bn + relu module fusion
conv_configs.append(
BackendPatternConfig((nn.ReLU, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
BackendPatternConfig((F.relu, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_fuser_method(reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.bn_qat))
# fused conv bn relu
conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn_relu)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.bn_relu_qat))
# conv bn, qat fused module
conv_configs.append(
BackendPatternConfig(convs.bn_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# conv bn relu, qat fused module
conv_configs.append(
BackendPatternConfig(convs.bn_relu_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# (4) conv transpose and its fusion
# 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.bn, convs.transpose))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(reverse2(fuse_conv_bn))
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return conv_configs
else:
for convs in _SpConvMetadatas:
# (1) Single conv modules/functions
# -----------------------------------
# conv module
conv_configs.append(
BackendPatternConfig(convs.root)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference)
.set_qat_module(convs.qat))
# conv qat module
conv_configs.append(
BackendPatternConfig(convs.qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# (2) Conv + relu
# -----------------
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs.append(
BackendPatternConfig((convs.root, torch.nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv relu fusion, conv module + functional relu
conv_configs.append(
BackendPatternConfig((convs.root, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference)
.set_qat_module(convs.relu_qat))
# conv relu, qat fused module
conv_configs.append(
BackendPatternConfig(convs.relu_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# fused conv relu
conv_configs.append(
BackendPatternConfig(convs.fused_conv_relu)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.relu_qat))
conv_configs.append(
BackendPatternConfig(convs.relu_qat)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# (3) Conv + batchnorm (+ relu)
# -------------------------------
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs.append(
BackendPatternConfig((convs.root, convs.bn))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(fuse_conv_bn)
.set_fused_module(convs.fused_conv_bn))
# conv + bn + relu module fusion
conv_configs.append(
BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
BackendPatternConfig((convs.root, convs.bn, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
.set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well
# 3.2 conv + bn (+ relu) fused module configs
# fused conv bn
conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.bn_qat))
# fused conv bn relu
conv_configs.append(
BackendPatternConfig(convs.fused_conv_bn_relu)
.set_dtype_configs(dtype_configs) # noqa: E131
.set_qat_module(convs.bn_relu_qat))
# conv bn, qat fused module
conv_configs.append(
BackendPatternConfig(convs.bn_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# conv bn relu, qat fused module
conv_configs.append(
BackendPatternConfig(convs.bn_relu_qat)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
.set_root_module(convs.root)
.set_reference_quantized_module(convs.reference))
# # (4) conv transpose and its fusion
# # 4.1 conv transpose config
# conv_configs.append(
# BackendPatternConfig(convs.transpose)
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
# # 4.2 conv transpose + bn fusion
# conv_configs.append(
# BackendPatternConfig((convs.transpose, convs.bn))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_conv_bn)
# .set_root_module(convs.transpose)
# .set_reference_quantized_module(convs.transpose_reference))
return conv_configs
weighted_op_qint8_dtype_config = DTypeConfig(
input_dtype=torch.qint8,
output_dtype=torch.qint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)
conv_dtype_configs = [
weighted_op_qint8_dtype_config,
]
_to_dense_cfg = (BackendPatternConfig(ToDense)
.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT))
backend_config = get_tensorrt_backend_config() \
.set_backend_pattern_configs(_get_spconv_configs(conv_dtype_configs) + [_to_dense_cfg])
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
}
def get_spconv_backend_config():
return backend_config
def get_spconv_prepare_custom_config():
cfg = PrepareCustomConfig()
cfg.non_traceable_module_classes = [*sconvmod.DEFAULT_SPARSE_CONV_TYPES]
return cfg
def get_spconv_convert_custom_config():
cfg = ConvertCustomConfig()
cfg.set_observed_to_quantized_mapping(snni.SpconvReLUNd, snniq.SparseConvReLU)
# cfg.set_observed_to_quantized_mapping(snni., snniq.SparseConvReLU)
return cfg
\ No newline at end of file
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize, fused_wt_fake_quant_range_neg_127_to_127
from spconv.pytorch.core import SparseConvTensor
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.fake_quantize import (
FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, FakeQuantize,
default_fused_per_channel_wt_fake_quant, default_weight_fake_quant, default_per_channel_weight_fake_quant)
from torch.ao.quantization.observer import (HistogramObserver,
MovingAverageMinMaxObserver,
default_weight_observer,
default_placeholder_observer,
default_per_channel_weight_observer)
from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER
from typing import Any, Callable, Dict, Tuple, Union, List
from torch.ao.quantization import get_default_qconfig
from spconv.pytorch.core import SparseConvTensor
__all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"]
class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
def forward(self, input:SparseConvTensor):
def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
if isinstance(input, SparseConvTensor):
# add lines to support spconv
x = input.features
res_features = super().forward(x)
return input.replace_feature(res_features)
else:
return super().forward(input)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor):
# # add lines to support spconv
# x = input.features
# res_features = super().forward(x)
# return input.replace_feature(res_features)
# else:
# return super().forward(input)
class SparseHistogramObserver(HistogramObserver):
def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
if isinstance(input, SparseConvTensor):
# add lines to support spconv
x = input.features
res_features = super().forward(x)
return input.replace_feature(res_features)
else:
return super().forward(input)
default_symmetric_spconv_ptq_qconfig = QConfig(
activation=SparseHistogramObserver.with_args(quant_min=-128,
quant_max=127,
dtype=torch.qint8,
reduce_range=False,
qscheme=torch.per_tensor_symmetric,
eps=2 ** -12),
weight=default_per_channel_weight_observer)
# default_symmetric_ptq_qconfig = QConfig(
# activation=HistogramObserver.with_args(quant_min=-128,
# quant_max=127,
# dtype=torch.qint8,
# reduce_range=False,
# eps=2 ** -12),
# weight=default_per_channel_weight_observer)
default_symmetric_spconv_qat_qconfig = QConfig(
activation=SparseFusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
......@@ -19,5 +69,78 @@ default_symmetric_spconv_qat_qconfig = QConfig(
quant_max=127,
dtype=torch.qint8,
reduce_range=False,
qscheme=torch.per_tensor_symmetric,
eps=2 ** -12),
weight=fused_wt_fake_quant_range_neg_127_to_127)
weight=default_fused_per_channel_wt_fake_quant)
def get_default_spconv_trt_ptq_qconfig(backend, version):
return default_symmetric_spconv_ptq_qconfig
def get_default_spconv_trt_qat_qconfig(backend, version):
return default_symmetric_spconv_qat_qconfig
def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "x86", version: int = 0) -> QConfigMapping:
"""
From torch.ao.quantization.qconfig_mapping
Return the default QConfigMapping for the given quantization type and backend.
"""
# get_default_qconfig(backend, version)
if is_qat:
qconfig = get_default_spconv_trt_qat_qconfig(backend, version)
else:
# qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False, dtype=torch.qint8),
# weight=default_per_channel_weight_observer)
qconfig = get_default_spconv_trt_ptq_qconfig(backend, version)
default_weight = default_weight_fake_quant if is_qat else default_weight_observer
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
# so we have to modify the weight observer to default_weight_observer or another
# per tensor supported observer.
# see https://github.com/pytorch/pytorch/issues/47535
if backend in ("fbgemm", "x86"):
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
else:
qconfig_transpose = qconfig
# currently layernorm only supports float weights
# we have to add this because otherwise there will be a extra quantize-dequantize pair
qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)
qconfig_mapping = QConfigMapping() \
.set_global(qconfig) \
.set_object_type("reshape", default_reuse_input_qconfig) \
.set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
.set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
.set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
.set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
.set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
.set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
.set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
.set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
# Use special observers for ops with fixed qparams
fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
if observer in fixed_qparams_observer_to_qconfig:
fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
else:
if is_qat:
activation = FixedQParamsFakeQuantize.with_args(observer=observer)
else:
activation = observer
fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight)
fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)
# QConfig for fused ops for onednn backend
# Separate ops are required to have the same qconfig as fused ops
# TODO: we should be able to configure qconfig for patterns
if backend == 'onednn':
qconfig_mapping.set_object_type(torch.nn.Linear, qconfig) \
.set_object_type(torch.nn.LeakyReLU, qconfig) \
.set_object_type(torch.nn.functional.leaky_relu, qconfig) \
.set_object_type(torch.nn.Tanh, qconfig) \
.set_object_type(torch.nn.functional.tanh, qconfig)
return qconfig_mapping
......@@ -3,9 +3,9 @@ import torch.nn as nn
import spconv.pytorch as spconv
from .utils import fuse_spconv_bn_eval
from . import intrinsic as snni
from .conv_fused import SparseConvBn, SparseConvBnReLU
def fuse_conv_bn(conv, bn):
from .intrinsic.qat.modules import SparseConvBn, SparseConvBnReLU, SparseConvBnAddReLU
from spconv.pytorch.conv import DEFAULT_SPARSE_CONV_TYPES
def fuse_conv_bn(is_qat, conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
......@@ -22,18 +22,10 @@ def fuse_conv_bn(conv, bn):
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map = {
spconv.SubMConv1d: snni.SpconvBnNd,
spconv.SparseConv1d: snni.SpconvBnNd,
spconv.SparseInverseConv1d: snni.SpconvBnNd,
spconv.SubMConv2d: snni.SpconvBnNd,
spconv.SparseConv2d: snni.SpconvBnNd,
spconv.SparseInverseConv2d: snni.SpconvBnNd,
spconv.SubMConv3d: snni.SpconvBnNd,
spconv.SparseConv3d: snni.SpconvBnNd,
spconv.SparseInverseConv3d: snni.SpconvBnNd,
k: snni.SpconvBnNd for k in DEFAULT_SPARSE_CONV_TYPES
}
if conv.training:
if is_qat:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
......@@ -45,7 +37,7 @@ def fuse_conv_bn(conv, bn):
else:
return fuse_spconv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
......@@ -61,17 +53,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
assert(conv.training == bn.training == relu.training),\
"Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[spconv.SparseSequential]] = None
if conv.training:
if is_qat:
map_to_fused_module_train = {
spconv.SubMConv1d: snni.SpconvBnReLUNd,
spconv.SparseConv1d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv1d: snni.SpconvBnReLUNd,
spconv.SubMConv2d: snni.SpconvBnReLUNd,
spconv.SparseConv2d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv2d: snni.SpconvBnReLUNd,
spconv.SubMConv3d: snni.SpconvBnReLUNd,
spconv.SparseConv3d: snni.SpconvBnReLUNd,
spconv.SparseInverseConv3d: snni.SpconvBnReLUNd,
k: snni.SpconvBnReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
}
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
......@@ -83,15 +67,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else:
map_to_fused_module_eval = {
spconv.SubMConv1d: snni.SpconvReLUNd,
spconv.SparseConv1d: snni.SpconvReLUNd,
spconv.SparseInverseConv1d: snni.SpconvReLUNd,
spconv.SubMConv2d: snni.SpconvReLUNd,
spconv.SparseConv2d: snni.SpconvReLUNd,
spconv.SparseInverseConv2d: snni.SpconvReLUNd,
spconv.SubMConv3d: snni.SpconvReLUNd,
spconv.SparseConv3d: snni.SpconvReLUNd,
spconv.SparseInverseConv3d: snni.SpconvReLUNd,
k: snni.SpconvReLUNd for k in DEFAULT_SPARSE_CONV_TYPES
}
fused_module = map_to_fused_module_eval.get(type(conv), None)
if fused_module is not None:
......@@ -100,31 +76,28 @@ def fuse_conv_bn_relu(conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
(spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
}
# DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
# (spconv.SubMConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv2d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SubMConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SubMConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d): fuse_conv_bn,
# (spconv.SparseInverseConv3d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
# }
# def get_spconv_fuse_method_mapping():
# return DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD
# Default map for swapping float module to qat modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
# nn.Conv2d: nnqat.Conv2d,
# Intrinsic modules:
snni.SpconvBnNd: SparseConvBn,
snni.SpconvBnReLUNd: SparseConvBnReLU,
}
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SpconvBnNd, SpconvBnReLUNd, SpconvBnAddReLUNd, SpconvReLUNd
......@@ -4,9 +4,28 @@ from torch.nn.utils.parametrize import type_before_parametrizations
import torch.ao.nn.intrinsic as nni
from spconv.pytorch.conv import SparseConvolution
from spconv.pytorch.modules import is_spconv_module
from spconv.pytorch.core import SparseConvTensor
class _FusedSparseModule(nni._FusedModule):
def forward(self, input):
for k, module in self._modules.items():
if is_spconv_module(module): # use SpConvTensor as input
if isinstance(input, list):
input = module(input)
else:
# assert isinstance(input, spconv.SparseConvTensor)
# self._sparity_dict[k] = input.sparity
input = module(input)
else:
if isinstance(input, SparseConvTensor):
if input.indices.shape[0] != 0:
input = input.replace_feature(module(input.features))
else:
input = module(input)
return input
class SpconvReLUNd(nni._FusedModule):
class SpconvReLUNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
......@@ -15,7 +34,7 @@ class SpconvReLUNd(nni._FusedModule):
type(conv), type(relu))
super().__init__(conv, relu)
class SpconvBnNd(nni._FusedModule):
class SpconvBnNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
......@@ -24,8 +43,16 @@ class SpconvBnNd(nni._FusedModule):
type(conv), type(bn))
super().__init__(conv, bn)
class SpconvBnReLUNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert isinstance(conv, SparseConvolution) and isinstance(bn, BatchNorm1d) and \
isinstance(relu, ReLU), 'Incorrect types for input modules{}{}{}' \
.format(type(conv), type(bn), type(relu))
super().__init__(conv, bn, relu)
class SpconvBnReLUNd(nni._FusedModule):
class SpconvBnAddReLUNd(_FusedSparseModule):
r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
......
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modules import SparseConvBn, SparseConvBnAddReLU, SparseConvBnReLU, SparseConv, SparseConvReLU
\ No newline at end of file
......@@ -6,8 +6,6 @@ import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils import fuse_conv_bn_weights
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter
from typing import TypeVar
from spconv.pytorch.conv import SparseConvolution
......@@ -16,9 +14,189 @@ from spconv.core import ConvAlgo
from cumm import tensorview as tv
from spconv.pytorch.core import SparseConvTensor
import spconv.pytorch.quantization.intrinsic as snni
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConv(SparseConvolution, nni._FusedModule):
_FLOAT_MODULE = MOD
_FLOAT_CONV_MODULE = SparseConvolution
def __init__(self,
ndim: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
transposed: bool = False,
inverse: bool = False,
indice_key: Optional[str] = None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
record_voxel_count: bool = False,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
act_alpha: float = 0,
act_beta: float = 0,
name=None,
qconfig=None,
device=None,
dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
SparseConvolution.__init__(self, ndim, in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
bias=False,
subm=subm,
output_padding=output_padding,
transposed=transposed,
inverse=inverse,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
record_voxel_count=record_voxel_count,
act_type=act_type,
act_alpha=act_alpha,
act_beta=act_beta,
name=name, **factory_kwargs)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return self._conv_forward(False, input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
def from_float(cls, mod):
r"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
)
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
if issubclass(type(mod), nni._FusedModule):
mod = mod[0] # type: ignore[index]
conv: SparseConvolution = mod
qconfig = mod.qconfig
qat_conv = cls(conv.ndim, conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups,
conv.bias is not None,
subm=conv.subm,
output_padding=conv.output_padding,
transposed=conv.transposed,
inverse=conv.inverse,
indice_key=conv.indice_key,
algo=conv.algo,
fp32_accum=conv.fp32_accum,
record_voxel_count=conv.record_voxel_count,
act_type=conv.act_type,
act_alpha=conv.act_alpha,
act_beta=conv.act_beta,
name=conv.name,
qconfig=qconfig)
qat_conv.weight = mod.weight
qat_conv.bias = mod.bias
return qat_conv
def to_float(self):
""" This works for both single qat conv, and the qat conv - relu modules
to convert the qat module to a floating point module
"""
cls = type(self)
conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
self.ndim,
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
self.groups,
self.bias is not None,
subm=self.subm,
output_padding=self.output_padding,
transposed=self.transposed,
inverse=self.inverse,
indice_key=self.indice_key,
algo=self.algo,
fp32_accum=self.fp32_accum,
record_voxel_count=self.record_voxel_count,
act_type=self.act_type,
act_alpha=self.act_alpha,
act_beta=self.act_beta,
name=self.name)
conv.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
conv.bias = torch.nn.Parameter(self.bias.detach())
# conv relu
if issubclass(cls, nni._FusedModule):
modules = [conv]
assert hasattr(cls, "_FLOAT_RELU_MODULE")
relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
modules.append(relu)
fused = cls._FLOAT_MODULE(*modules) # type: ignore[arg-type, attr-defined, operator]
fused.train(self.training)
return fused
else:
return conv
class SparseConv(_SparseConv, SparseConvolution):
r"""
A Conv1d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv1d`
Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = SparseConvolution
_FLOAT_CONV_MODULE = SparseConvolution
@classmethod
def from_float(cls, mod):
return super().from_float(cls, mod)
class SparseConvReLU(SparseConv, nni._FusedModule):
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for weight for
quantization aware training.
We combined the interface of :class:`~torch.nn.Conv2d` and
:class:`~torch.nn.BatchNorm2d`.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = snni.SpconvReLUNd
_FLOAT_CONV_MODULE = SparseConvolution
_FLOAT_BN_MODULE = None
_FLOAT_RELU_MODULE = nn.ReLU
def forward(self, input):
x = self._conv_forward(self.training, input, self.weight_fake_quant(self.weight), self.bias)
return x.replace_feature(F.relu(x.features))
@classmethod
def from_float(cls, mod):
return super(SparseConvReLU, cls).from_float(mod)
class _SparseConvBn(SparseConvolution, nni._FusedModule):
_version = 2
......@@ -34,7 +212,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
stride: Union[int, List[int], Tuple[int, ...]] = 1,
padding: Union[int, List[int], Tuple[int, ...]] = 0,
dilation: Union[int, List[int], Tuple[int, ...]] = 1,
groups: Union[int, List[int], Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
subm: bool = False,
output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
......@@ -143,7 +321,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
zero_bias = torch.zeros_like(self.bias, dtype=input.features.dtype)
else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
conv_spt = self._conv_forward(input, scaled_weight, zero_bias)
conv_spt = self._conv_forward(self.training, input, scaled_weight, zero_bias)
conv = conv_spt.features
conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None:
......@@ -396,7 +574,7 @@ class _SparseConvBn(SparseConvolution, nni._FusedModule):
if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
# fuse bn into conv
conv.weight, conv.bias = fuse_conv_bn_weights(
conv.weight, conv.bias = fuse_spconv_bn_weights(
conv.weight,
conv.bias,
self.bn.running_mean,
......@@ -473,3 +651,35 @@ class SparseConvBnReLU(_SparseConvBn):
@classmethod
def from_float(cls, mod):
return super(SparseConvBnReLU, cls).from_float(mod)
class SparseConvBnAddReLU(_SparseConvBn):
r"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
# base class defines _FLOAT_MODULE as "ConvBn1d"
_FLOAT_MODULE = snni.SpconvBnReLUNd # type: ignore[assignment]
_FLOAT_CONV_MODULE = SparseConvolution
_FLOAT_BN_MODULE = nn.BatchNorm1d
_FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment]
# module class after fusing bn into conv
_FUSED_FLOAT_MODULE = snni.SpconvReLUNd
def forward(self, input, add_input):
x = _SparseConvBn._forward(self, input, add_input)
return x.replace_feature(F.relu(x.features))
@classmethod
def from_float(cls, mod):
return super(SparseConvBnAddReLU, cls).from_float(mod)
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv_relu import *
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment