Commit 5b3fe9e7 authored by yan.yan's avatar yan.yan
Browse files

sync quantization code

parent e387ee74
...@@ -250,8 +250,8 @@ class NetV2(nn.Module): ...@@ -250,8 +250,8 @@ class NetV2(nn.Module):
) )
self.fc1 = nn.Linear(14 * 14 * 64, 128) self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10) self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25) # self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5) # self.dropout2 = nn.Dropout2d(0.5)
self.quant = QuantStub() self.quant = QuantStub()
self.dequant = DeQuantStub() self.dequant = DeQuantStub()
...@@ -263,10 +263,10 @@ class NetV2(nn.Module): ...@@ -263,10 +263,10 @@ class NetV2(nn.Module):
# create SparseConvTensor manually: see SparseConvTensor.from_dense # create SparseConvTensor manually: see SparseConvTensor.from_dense
x = self.net(x_sp) x = self.net(x_sp)
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
x = self.dropout1(x) # x = self.dropout1(x)
x = self.fc1(x) x = self.fc1(x)
x = F.relu(x) x = F.relu(x)
x = self.dropout2(x) # x = self.dropout2(x)
x = self.fc2(x) x = self.fc2(x)
x = self.dequant(x) x = self.dequant(x)
output = F.log_softmax(x, dim=1) output = F.log_softmax(x, dim=1)
...@@ -474,22 +474,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device): ...@@ -474,22 +474,6 @@ def calibrate(args, model: torch.nn.Module, data_loader, device):
else: else:
output = model(image) output = model(image)
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for node in m.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.quantize_per_tensor:
node.target = quantize_per_tensor
m.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
m.recompile()
return m
def is_dequantize_node(node): def is_dequantize_node(node):
...@@ -522,6 +506,23 @@ def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule): ...@@ -522,6 +506,23 @@ def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
# Graph is well-formed. # Graph is well-formed.
return model return model
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for node in m.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.quantize_per_tensor:
node.target = quantize_per_tensor
m.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
m.recompile()
return m
def main(): def main():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
...@@ -561,7 +562,7 @@ def main(): ...@@ -561,7 +562,7 @@ def main():
help='random seed (default: 1)') help='random seed (default: 1)')
parser.add_argument('--sparse', parser.add_argument('--sparse',
action='store_true', action='store_true',
default=True, default=False,
help='use sparse conv network instead of dense') help='use sparse conv network instead of dense')
parser.add_argument( parser.add_argument(
'--log-interval', '--log-interval',
...@@ -588,7 +589,7 @@ def main(): ...@@ -588,7 +589,7 @@ def main():
qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu") qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
if args.sparse: if args.sparse:
model = ResidualNetPTQ().to(device) model = NetV2().to(device)
else: else:
model = NetDense().to(device) model = NetDense().to(device)
...@@ -647,7 +648,7 @@ def main(): ...@@ -647,7 +648,7 @@ def main():
# print(prepared_model) # print(prepared_model)
# calibrate: run model with some inputs # calibrate: run model with some inputs
# calibrate(args, prepared_model, test_loader, qdevice) calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules # convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg) converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
converted_model = transform_qdq(converted_model) converted_model = transform_qdq(converted_model)
......
...@@ -269,7 +269,8 @@ class SimpleGemm: ...@@ -269,7 +269,8 @@ class SimpleGemm:
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
def _compile_nvrtc_module(self, desp: GemmAlgoDesp): @staticmethod
def _compile_nvrtc_module(desp: GemmAlgoDesp):
params = algocore.get_gemm_param_from_desp(desp) params = algocore.get_gemm_param_from_desp(desp)
kernel = gen_gemm_kernels(params, SPCONV_NVRTC_MODE) kernel = gen_gemm_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv" kernel.namespace = "spconv"
...@@ -808,7 +809,8 @@ class SimpleConv: ...@@ -808,7 +809,8 @@ class SimpleConv:
return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk, return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk,
kv) kv)
def _compile_nvrtc_module(self, desp: ConvAlgoDesp): @staticmethod
def _compile_nvrtc_module(desp: ConvAlgoDesp):
params = algocore.get_conv_param_from_desp(desp) params = algocore.get_conv_param_from_desp(desp)
kernel = gen_conv_kernels(params, SPCONV_NVRTC_MODE) kernel = gen_conv_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv" kernel.namespace = "spconv"
...@@ -824,9 +826,8 @@ class SimpleConv: ...@@ -824,9 +826,8 @@ class SimpleConv:
cudadevrt = str(cudadevrt_p) cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel], mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt, cudadevrt_path=cudadevrt,
verbose=True, verbose=False,
custom_names=custom_names, custom_names=custom_names)
verbose_path="/home/yy/Projects/spconv-release/spconv/build/dev_nvrtc_int8")
mod.load() mod.load()
return mod, kernel return mod, kernel
...@@ -870,7 +871,7 @@ class SimpleConv: ...@@ -870,7 +871,7 @@ class SimpleConv:
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
output = output.clone() output = output.clone()
print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty()) # print(len(avail), inp.dtype, weight.dtype, output.dtype, bias.dtype, scale.dtype, bias.empty(), scale.empty())
channel_k = output.dim(1) channel_k = output.dim(1)
channel_c = inp.dim(1) channel_c = inp.dim(1)
weight = weight.view([channel_k, -1, channel_c]) weight = weight.view([channel_k, -1, channel_c])
......
...@@ -410,6 +410,7 @@ IMPLGEMM_SIMT_PARAMS = [ ...@@ -410,6 +410,7 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_VOLTA_PARAMS = [ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
...@@ -618,12 +619,26 @@ IMPLGEMM_AMPERE_PARAMS = [ ...@@ -618,12 +619,26 @@ IMPLGEMM_AMPERE_PARAMS = [
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
] ]
if not SPCONV_INT8_DEBUG: if not SPCONV_INT8_DEBUG:
IMPLGEMM_AMPERE_PARAMS.extend([ IMPLGEMM_AMPERE_PARAMS.extend([
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (32, 32, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
...@@ -632,14 +647,28 @@ if not SPCONV_INT8_DEBUG: ...@@ -632,14 +647,28 @@ if not SPCONV_INT8_DEBUG:
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Ampere, GemmAlgo.Ampere,
TensorOp((16, 8, 32)), TensorOp((16, 8, 16)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 32), (32, 32, 32),
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
...@@ -655,7 +684,7 @@ if not SPCONV_INT8_DEBUG: ...@@ -655,7 +684,7 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
...@@ -671,7 +700,7 @@ if not SPCONV_INT8_DEBUG: ...@@ -671,7 +700,7 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
...@@ -687,11 +716,27 @@ if not SPCONV_INT8_DEBUG: ...@@ -687,11 +716,27 @@ if not SPCONV_INT8_DEBUG:
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
[2, 3, 4], [2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16", "s8,s8,f32,s32,f32", "s8,s8,f32,s32,f16", "s8,s8,f16,s32,f32", "s8,s8,f16,s32,f16"], ["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Ampere,
TensorOp((16, 8, 32)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
[2, 3, 4],
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -1073,6 +1118,51 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -1073,6 +1118,51 @@ IMPLGEMM_TURING_PARAMS = [
if not SPCONV_INT8_DEBUG: if not SPCONV_INT8_DEBUG:
IMPLGEMM_TURING_PARAMS.extend([ IMPLGEMM_TURING_PARAMS.extend([
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32", "s8,s8,s8,s32,f16"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=True,
int8_inference=True),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
...@@ -1136,7 +1226,7 @@ if not SPCONV_INT8_DEBUG: ...@@ -1136,7 +1226,7 @@ if not SPCONV_INT8_DEBUG:
access_per_vector=1, access_per_vector=1,
is_nvrtc=True, is_nvrtc=True,
int8_inference=True), int8_inference=True),
# TODO 16,8,32 produce wrong result.
*gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
......
import platform import platform
from pathlib import Path from pathlib import Path
from typing import Union
import numpy as np import numpy as np
import torch import torch
......
...@@ -267,18 +267,30 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -267,18 +267,30 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
# return self.indices.shape[0] / np.prod( # return self.indices.shape[0] / np.prod(
# self.spatial_shape) / self.batch_size # self.spatial_shape) / self.batch_size
def __add__(self, other: "SparseConvTensor"): def __add__(self, other: Union["SparseConvTensor", torch.Tensor]):
assert isinstance(other, SparseConvTensor) assert isinstance(other, (SparseConvTensor, torch.Tensor))
return self.replace_feature(self.features + other.features) if isinstance(other, torch.Tensor):
other_features = other
else:
other_features = other.features
return self.replace_feature(self.features + other_features)
def __iadd__(self, other: "SparseConvTensor"): def __iadd__(self, other: Union["SparseConvTensor", torch.Tensor]):
assert isinstance(other, SparseConvTensor) assert isinstance(other, (SparseConvTensor, torch.Tensor))
self.features += other.features if isinstance(other, torch.Tensor):
other_features = other
else:
other_features = other.features
self.features += other_features
return self return self
def __radd__(self, other: "SparseConvTensor"): def __radd__(self, other: Union["SparseConvTensor", torch.Tensor]):
assert isinstance(other, SparseConvTensor) assert isinstance(other, (SparseConvTensor, torch.Tensor))
return other.replace_feature(self.features + other.features) if isinstance(other, torch.Tensor):
other_features = other
else:
other_features = other.features
return self.replace_feature(self.features + other_features)
def shadow_copy(self) -> "SparseConvTensor": def shadow_copy(self) -> "SparseConvTensor":
"""create a new spconv tensor with all member unchanged""" """create a new spconv tensor with all member unchanged"""
......
...@@ -137,6 +137,10 @@ class TorchAllocator(ExternalAllocator): ...@@ -137,6 +137,10 @@ class TorchAllocator(ExternalAllocator):
else: else:
ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_() ten = torch.empty(shape, dtype=th_dtype, device=dev).zero_()
ten_tv = torch_tensor_to_tv(ten, dtype_bkp) ten_tv = torch_tensor_to_tv(ten, dtype_bkp)
# if self.is_quantized:
# ctx = tv.Context()
# ctx.set_cuda_stream(stream)
# ten_tv.zero_(ctx)
self.allocated[ten_tv.byte_pointer()] = ten self.allocated[ten_tv.byte_pointer()] = ten
if name and not is_temp_memory: if name and not is_temp_memory:
self.allocated[name] = ten self.allocated[name] = ten
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import sys import sys
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import Union
import torch import torch
from torch import nn from torch import nn
...@@ -182,3 +183,24 @@ class SparseIdentity(nn.Identity): ...@@ -182,3 +183,24 @@ class SparseIdentity(nn.Identity):
if isinstance(input, spconv.SparseConvTensor): if isinstance(input, spconv.SparseConvTensor):
return input.replace_feature(super().forward(input.features)) return input.replace_feature(super().forward(input.features))
return super().forward(input) return super().forward(input)
class PrintTensorMeta(nn.Module):
def forward(self, x: Union[spconv.SparseConvTensor, torch.Tensor]):
if isinstance(x, torch.Tensor):
print(x.min(), x.max(), x.mean())
elif isinstance(x, spconv.SparseConvTensor):
ft = x.features
print(ft.min(), ft.max(), ft.mean())
return x
class PrintCurrentTime(nn.Module):
def __init__(self) -> None:
super().__init__()
self.first_time = time.time()
def forward(self, x, msg="", reset: bool = False):
if reset:
self.first_time = time.time()
torch.cuda.synchronize()
print(msg, time.time() - self.first_time)
return x
...@@ -16,7 +16,10 @@ from .backend_cfg import (get_spconv_backend_config, ...@@ -16,7 +16,10 @@ from .backend_cfg import (get_spconv_backend_config,
get_spconv_prepare_custom_config, get_spconv_prepare_custom_config,
get_spconv_convert_custom_config) get_spconv_convert_custom_config)
from .fake_q import (get_default_spconv_trt_ptq_qconfig, from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig) get_default_spconv_trt_qat_qconfig,
get_default_spconv_qconfig_mapping)
from .qmapping import (get_spconv_fmod_to_qat_mapping, from .qmapping import (get_spconv_fmod_to_qat_mapping,
get_spconv_qat_to_static_mapping) get_spconv_qat_to_static_mapping)
from .core import quantize_per_tensor from .core import quantize_per_tensor
from .graph import remove_conv_add_dq, transform_qdq
\ No newline at end of file
from collections import namedtuple
import operator import operator
from typing import Dict, List, Tuple, Type, Union from collections import namedtuple
from typing import Dict, List, Optional, Tuple, Type, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from torch.ao.quantization.fx.match_utils import (
MatchAllNode, )
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.backend_config import (BackendConfig, from torch.ao.quantization.backend_config import (BackendConfig,
BackendPatternConfig, BackendPatternConfig,
...@@ -15,9 +13,12 @@ from torch.ao.quantization.backend_config import (BackendConfig, ...@@ -15,9 +13,12 @@ from torch.ao.quantization.backend_config import (BackendConfig,
from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig, from torch.ao.quantization.fx.custom_config import (ConvertCustomConfig,
FuseCustomConfig, FuseCustomConfig,
PrepareCustomConfig) PrepareCustomConfig)
from torch.ao.quantization.fx.match_utils import MatchAllNode
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized._reference as nnqr
import spconv.pytorch.conv as sconvmod import spconv.pytorch.conv as sconvmod
from spconv.pytorch.modules import SparseBatchNorm, SparseIdentity, SparseReLU, SparseSyncBatchNorm
import spconv.pytorch.quantization.intrinsic as snni import spconv.pytorch.quantization.intrinsic as snni
import spconv.pytorch.quantization.intrinsic.qat as snniqat import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic.quantized as snniq import spconv.pytorch.quantization.intrinsic.quantized as snniq
...@@ -25,10 +26,15 @@ import spconv.pytorch.quantization.quantized as snnq ...@@ -25,10 +26,15 @@ import spconv.pytorch.quantization.quantized as snnq
import spconv.pytorch.quantization.quantized.reference as snnqr import spconv.pytorch.quantization.quantized.reference as snnqr
from spconv.pytorch import ToDense from spconv.pytorch import ToDense
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.modules import (PrintTensorMeta, SparseBatchNorm,
SparseIdentity, SparseReLU,
SparseSyncBatchNorm, PrintCurrentTime)
from spconv.pytorch.pool import ALL_POOL_LAYERS from spconv.pytorch.pool import ALL_POOL_LAYERS
from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn, from spconv.pytorch.quantization.fuse_mapping import (fuse_conv_bn,
fuse_conv_bn_relu, fuse_conv_bn_add_relu,
fuse_conv_bn_add_relu) fuse_conv_bn_relu)
_SpConvMetadataDef = namedtuple("_ConvMetadata", [ _SpConvMetadataDef = namedtuple("_ConvMetadata", [
"root", "bn", "reference", "fused_conv_relu", "fused_conv_bn", "root", "bn", "reference", "fused_conv_relu", "fused_conv_bn",
...@@ -105,6 +111,31 @@ def _conv_res_relu_extra_inputs_getter(pattern): ...@@ -105,6 +111,31 @@ def _conv_res_relu_extra_inputs_getter(pattern):
return [extra_input] return [extra_input]
# def _get_custom_bn_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
# """
# Return all configs related to linear modules and ops.
# """
# observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
# linear_configs: List[BackendPatternConfig] = []
# # (3) Linear + batchnorm
# # ------------------------
# # 3.1 linear bn fusion
# if PYTORCH_VERSION[:2] <= [1, 13]:
# linear_configs.append(
# BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_linear_bn)
# .set_fused_module(nni.LinearBn1d))
# else:
# linear_configs.append(
# BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
# .set_dtype_configs(dtype_configs) # noqa: E131
# .set_fuser_method(fuse_linear_bn)
# .set_fused_module(nni.LinearBn1d))
# return linear_configs
def _get_bn_spconv_configs(bn_cls, dtype_configs): def _get_bn_spconv_configs(bn_cls, dtype_configs):
""" """
Return all configs related to conv modules and ops. Return all configs related to conv modules and ops.
...@@ -526,6 +557,9 @@ def _get_share_observer_ops(dtype_configs): ...@@ -526,6 +557,9 @@ def _get_share_observer_ops(dtype_configs):
res.append(_to_dense_cfg) res.append(_to_dense_cfg)
res.append(iden_cfg) res.append(iden_cfg)
res.append(BackendPatternConfig(PrintCurrentTime).set_observation_type(
ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT).set_dtype_configs(
dtype_configs))
for p in ALL_POOL_LAYERS: for p in ALL_POOL_LAYERS:
_pool_cfg = (BackendPatternConfig(p).set_observation_type( _pool_cfg = (BackendPatternConfig(p).set_observation_type(
...@@ -551,31 +585,40 @@ conv_dtype_configs = [ ...@@ -551,31 +585,40 @@ conv_dtype_configs = [
weighted_op_qint8_dtype_config, weighted_op_qint8_dtype_config,
] ]
backend_config = get_tensorrt_backend_config() \
.set_backend_pattern_configs(_get_spconv_configs(conv_dtype_configs) + _get_share_observer_ops([non_weighted_op_qint8_dtype_config]))
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[ SPCONV_STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[
Type[nn.Module], Type[WeightedQuantizedModule]]] = { Type[nn.Module], Type[WeightedQuantizedModule]]] = {
snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU), snni.SpconvReLUNd: (snnqr.SpConv, snniq.SparseConvReLU),
snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU), snni.SpconvAddReLUNd: (snnqr.SpConv, snniq.SparseConvAddReLU),
# use simple cumm i8 conv to implement linear
nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
} }
SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], SPCONV_STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module],
Type[WeightedQuantizedModule]] = { Type[WeightedQuantizedModule]] = {
snnqr.SpConv: snnq.SparseConv, snnqr.SpConv: snnq.SparseConv,
nnqr.Linear: snnq.LinearPerChannelWeight,
} }
def get_spconv_backend_config(): def get_spconv_backend_config(additional_bns: Optional[List[Type[nn.Module]]] = None):
backend_config = get_tensorrt_backend_config() \
.set_backend_pattern_configs(_get_spconv_configs(conv_dtype_configs) + _get_share_observer_ops([non_weighted_op_qint8_dtype_config]))
if additional_bns is not None:
for bn_type in additional_bns:
backend_config.set_backend_pattern_configs(_get_bn_spconv_configs(bn_type, conv_dtype_configs))
return backend_config return backend_config
def get_spconv_prepare_custom_config(): def get_spconv_prepare_custom_config(additional_bns: Optional[List[Type[nn.Module]]] = None):
cfg = PrepareCustomConfig() cfg = PrepareCustomConfig()
cfg.non_traceable_module_classes = [*sconvmod.DEFAULT_SPARSE_CONV_TYPES] cfg.non_traceable_module_classes = [*sconvmod.DEFAULT_SPARSE_CONV_TYPES]
cfg.non_traceable_module_classes.extend( cfg.non_traceable_module_classes.extend(
[SparseReLU, SparseBatchNorm, SparseSyncBatchNorm]) [SparseReLU, SparseBatchNorm, SparseSyncBatchNorm, PrintTensorMeta,
PrintCurrentTime])
if additional_bns is not None:
cfg.non_traceable_module_classes.extend(additional_bns)
return cfg return cfg
......
...@@ -3,8 +3,12 @@ from typing import Union, List, Dict ...@@ -3,8 +3,12 @@ from typing import Union, List, Dict
import torch import torch
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from cumm import tensorview as tv
from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[Union[SparseConvTensor, torch.Tensor]]], scale, zero_point, dtype): def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[Union[SparseConvTensor, torch.Tensor]]], scale, zero_point, dtype):
# with tv.measure_and_print("quantize_per_tensor", stream=get_current_stream()):
if isinstance(ten, (list, tuple)): if isinstance(ten, (list, tuple)):
res = [] res = []
for i, v in enumerate(ten): for i, v in enumerate(ten):
...@@ -19,3 +23,15 @@ def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[U ...@@ -19,3 +23,15 @@ def quantize_per_tensor(ten: Union[Union[SparseConvTensor, torch.Tensor], List[U
else: else:
return torch.quantize_per_tensor(ten, scale, zero_point, dtype) return torch.quantize_per_tensor(ten, scale, zero_point, dtype)
def quantized_add(x: torch.Tensor, y: torch.Tensor, scale, zero_point):
x_detach = torch.zeros(size=x.shape, dtype=torch.int8, device=x.device)
y_detach = torch.zeros(size=y.shape, dtype=torch.int8, device=y.device)
torch_tensor_to_tv(x_detach).copy_(torch_tensor_to_tv(x))
torch_tensor_to_tv(y_detach).copy_(torch_tensor_to_tv(y))
res = (x_detach.to(torch.float32) * x.q_scale() + y_detach.to(torch.float32) * y.q_scale()) / scale
res = torch.clip(torch.round(res), -128, 127).to(torch.int8)
res_q = torch._empty_affine_quantized(size=res.shape, dtype=torch.qint8, scale=scale, zero_point=zero_point, device=x.device)
torch_tensor_to_tv(res_q, tv.int8).copy_(torch_tensor_to_tv(res))
return res_q
from typing import Any, Callable, Dict, List, Tuple, Union
import torch import torch
from torch.ao.quantization import get_default_qat_qconfig, get_default_qconfig
from torch.ao.quantization.fake_quantize import ( from torch.ao.quantization.fake_quantize import (
FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, FakeQuantize, FakeQuantize, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize,
default_fused_per_channel_wt_fake_quant, default_weight_fake_quant, default_per_channel_weight_fake_quant) default_fused_per_channel_wt_fake_quant,
from torch.ao.quantization.observer import (HistogramObserver, default_per_channel_weight_fake_quant, default_weight_fake_quant)
MovingAverageMinMaxObserver, from torch.ao.quantization.observer import (
default_weight_observer, MinMaxObserver,
default_placeholder_observer, HistogramObserver, MovingAverageMinMaxObserver,
default_per_channel_weight_observer) default_per_channel_weight_observer, default_placeholder_observer,
from torch.ao.quantization.qconfig import QConfig, QConfigAny, default_reuse_input_qconfig default_weight_observer)
from torch.ao.quantization.qconfig_mapping import QConfigMapping, _FIXED_QPARAMS_OP_TO_OBSERVER from torch.ao.quantization.qconfig import (QConfig, QConfigAny,
from typing import Any, Callable, Dict, Tuple, Union, List default_reuse_input_qconfig)
from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig from torch.ao.quantization.qconfig_mapping import (
_FIXED_QPARAMS_OP_TO_OBSERVER, QConfigMapping)
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.modules import PrintTensorMeta, PrintCurrentTime
__all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"] __all__ = ["get_default_spconv_trt_ptq_qconfig", "get_default_spconv_trt_qat_qconfig"]
...@@ -26,6 +32,16 @@ class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize): ...@@ -26,6 +32,16 @@ class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
else: else:
return super().forward(input) return super().forward(input)
class SparseMovingAvgObsFakeQuantize(FakeQuantize):
def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
if isinstance(input, SparseConvTensor):
# add lines to support spconv
x = input.features
res_features = super().forward(x)
return input.replace_feature(res_features)
else:
return super().forward(input)
# class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize): # class SparseMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
# def forward(self, input:Union[SparseConvTensor, torch.Tensor]): # def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
# if isinstance(input, SparseConvTensor): # if isinstance(input, SparseConvTensor):
...@@ -46,6 +62,16 @@ class SparseHistogramObserver(HistogramObserver): ...@@ -46,6 +62,16 @@ class SparseHistogramObserver(HistogramObserver):
else: else:
return super().forward(input) return super().forward(input)
class SparseMinMaxObserver(MinMaxObserver):
def forward(self, input:Union[SparseConvTensor, torch.Tensor]):
if isinstance(input, SparseConvTensor):
# add lines to support spconv
x = input.features
res_features = super().forward(x)
return input.replace_feature(res_features)
else:
return super().forward(input)
default_symmetric_spconv_ptq_qconfig = QConfig( default_symmetric_spconv_ptq_qconfig = QConfig(
activation=SparseHistogramObserver.with_args(quant_min=-128, activation=SparseHistogramObserver.with_args(quant_min=-128,
quant_max=127, quant_max=127,
...@@ -143,6 +169,8 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "fbgemm", ve ...@@ -143,6 +169,8 @@ def get_default_spconv_qconfig_mapping(is_qat: bool, backend: str = "fbgemm", ve
.set_object_type(torch.nn.functional.leaky_relu, qconfig) \ .set_object_type(torch.nn.functional.leaky_relu, qconfig) \
.set_object_type(torch.nn.Tanh, qconfig) \ .set_object_type(torch.nn.Tanh, qconfig) \
.set_object_type(torch.nn.functional.tanh, qconfig) .set_object_type(torch.nn.functional.tanh, qconfig)
qconfig_mapping.set_object_type(PrintTensorMeta, None)
qconfig_mapping.set_object_type(PrintCurrentTime, None)
return qconfig_mapping return qconfig_mapping
import torch.fx
import torch
from torch import nn
from typing import Dict, Optional
from spconv.pytorch.quantization.core import quantize_per_tensor, quantized_add
import spconv.pytorch.quantization.intrinsic.quantized as snniq
def is_dequantize_node(node):
return isinstance(node, torch.fx.Node) and node.op == "call_method" and node.target == "dequantize"
def _get_module(node: torch.fx.Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""
Return the `torch.nn.Module` that corresponds to the specified node's target.
If no such node exists, return None.
"""
if node.op == "call_module" and str(node.target) in modules:
return modules[str(node.target)]
else:
return None
def remove_conv_add_dq(model: torch.fx.graph_module.GraphModule):
modules = dict(model.named_modules(remove_duplicate=False))
for n in model.graph.nodes:
if (n.op == "call_module" and type(_get_module(n, modules)) == snniq.SparseConvAddReLU):
# check second input, if it's dequantized, remove that dequantize node
arg1 = n.args[1]
if is_dequantize_node(arg1):
dq_node = arg1
assert(isinstance(dq_node, torch.fx.Node))
dn_input = dq_node.args[0]
n.replace_input_with(dq_node, dn_input)
model.graph.eliminate_dead_code()
model.recompile()
model.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return model
def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""torch.quantize_per_tensor don't support SparseConvTensor, so we
use a custom one by fx transform.
"""
for node in m.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.quantize_per_tensor:
node.target = quantize_per_tensor
if node.target == torch.ops.quantized.add:
node.target = quantized_add
m.graph.lint() # Does some checks to make sure the
# Graph is well-formed.
m.recompile()
return m
...@@ -14,16 +14,18 @@ ...@@ -14,16 +14,18 @@
from typing import Optional from typing import Optional
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.cppcore import get_current_stream
import spconv.pytorch.quantization.quantized as nnq import spconv.pytorch.quantization.quantized as nnq
from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
import torch.ao.nn.intrinsic as nni
import spconv.pytorch.quantization.intrinsic.qat as snniqat import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic as snni import spconv.pytorch.quantization.intrinsic as snni
import torch import torch
__all__ = ["SparseConvReLU", "SparseConvAddReLU"] __all__ = ["SparseConvReLU", "SparseConvAddReLU", "LinearPerChannelWeightReLU"]
class SparseConvReLU(nnq.SparseConv): class SparseConvReLU(nnq.SparseConv):
r""" r"""
...@@ -38,6 +40,10 @@ class SparseConvReLU(nnq.SparseConv): ...@@ -38,6 +40,10 @@ class SparseConvReLU(nnq.SparseConv):
_FLOAT_MODULE = SpconvReLUNd # type: ignore[assignment] _FLOAT_MODULE = SpconvReLUNd # type: ignore[assignment]
def forward(self, input): def forward(self, input):
msg = f"{input.features.shape[0]}, {input.features.shape[1]}, {self.weight().shape[0]}"
with tv.measure_and_print(f"QuantizedSparseConvReLU|{msg}", get_current_stream(), enable=False):
inp_scale = input.q_scale() inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales().to(torch.float32) w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale out_scale = self.scale
...@@ -80,6 +86,9 @@ class SparseConvAddReLU(nnq.SparseConv): ...@@ -80,6 +86,9 @@ class SparseConvAddReLU(nnq.SparseConv):
_FLOAT_MODULE = SpconvAddReLUNd # type: ignore[assignment] _FLOAT_MODULE = SpconvAddReLUNd # type: ignore[assignment]
def forward(self, input, add_input: Optional[SparseConvTensor] = None): def forward(self, input, add_input: Optional[SparseConvTensor] = None):
msg = f"{input.features.shape[0]}, {input.features.shape[1]}, {self.weight().shape[0]}"
with tv.measure_and_print(f"QuantizedSparseConvAddReLU|{msg}", get_current_stream(), enable=False):
inp_scale = input.q_scale() inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales().to(torch.float32) w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale out_scale = self.scale
...@@ -92,7 +101,7 @@ class SparseConvAddReLU(nnq.SparseConv): ...@@ -92,7 +101,7 @@ class SparseConvAddReLU(nnq.SparseConv):
return res return res
def _get_name(self): def _get_name(self):
return 'QuantizedSparseConvReLU' return 'QuantizedSparseConvAddReLU'
@classmethod @classmethod
def from_float(cls, mod): def from_float(cls, mod):
...@@ -107,3 +116,44 @@ class SparseConvAddReLU(nnq.SparseConv): ...@@ -107,3 +116,44 @@ class SparseConvAddReLU(nnq.SparseConv):
assert type(ref_qconv) != snni.SpconvBnReLUNd, \ assert type(ref_qconv) != snni.SpconvBnReLUNd, \
"BatchNorm1d should be fused into Conv1d before converting to reference module" "BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point) return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class LinearPerChannelWeightReLU(nnq.LinearPerChannelWeight):
r"""
A LinearPerChannelWeight module fused from Linear and ReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out, nvrtc_params = self._linear_fwd(x, self.weight(), self.bias(), self.scale, tv.gemm.Activation.ReLU, self._nvrtc_params)
if self._nvrtc_params is None:
self._nvrtc_params = nvrtc_params
return out
def _get_name(self):
return 'QuantizedLinearPerChannelWeightReLU'
@classmethod
def from_float(cls, mod):
return super(LinearPerChannelWeightReLU, cls).from_float(mod)
@classmethod
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
return super().from_reference(ref_linear_relu[0], output_scale, output_zero_point)
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .conv import SparseConv from .conv import SparseConv, LinearPerChannelWeight
\ No newline at end of file \ No newline at end of file
...@@ -16,11 +16,29 @@ from spconv.pytorch.core import SparseConvTensor ...@@ -16,11 +16,29 @@ from spconv.pytorch.core import SparseConvTensor
from torch._ops import ops from torch._ops import ops
from torch.nn.common_types import _size_1_t from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.modules.utils import _single, _pair, _triple
from collections.abc import Iterable
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule, _quantize_weight from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule, _quantize_weight
import spconv.pytorch.quantization.intrinsic.qat.modules as snniqat import spconv.pytorch.quantization.intrinsic.qat.modules as snniqat
import spconv.pytorch.quantization.intrinsic.modules as snni import spconv.pytorch.quantization.intrinsic.modules as snni
from spconv.pytorch.quantization.utils import fuse_spconv_bn_eval, fuse_spconv_bn_weights from spconv.pytorch.quantization.utils import fuse_spconv_bn_eval, fuse_spconv_bn_weights
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvOpType as ConvOpTypeCpp
from spconv.constants import (NDIM_DONT_CARE, SPCONV_BWD_SPLITK,
SPCONV_NVRTC_MODE, SPCONV_DEBUG_NVRTC_KERNELS)
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from spconv import algocore
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream
import torch.ao.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
from spconv.algo import _get_nvrtc_params, SimpleConv
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
from cumm.gemm.algospec.core import TensorOp
class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule): class _SparseConv(SparseConvolutionBase, WeightedQuantizedModule):
_FLOAT_MODULE = SparseConvolution _FLOAT_MODULE = SparseConvolution
...@@ -359,3 +377,319 @@ class SparseConv(_SparseConv): ...@@ -359,3 +377,319 @@ class SparseConv(_SparseConv):
return _SparseConv.from_float(cls, mod) return _SparseConv.from_float(cls, mod)
class LinearPerChannelWeight(WeightedQuantizedModule):
r"""
A quantized linear module with quantized tensor as inputs and outputs.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
This module use conv int8 in cumm to provide qcuda int8 debug.
Similar to :class:`~torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized to zero.
scale: `scale` parameter of output Quantized Tensor, type: double
zero_point: `zero_point` parameter for output Quantized Tensor, type: long
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_version = 3
_FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
CUMM_CONV_PARAMS = [ *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
2,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
access_per_vector=1,
is_nvrtc=True,
int8_inference=True,
dynamic_mask=False),
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
2,
ConvIterAlgo.Optimized,
2,
["s8,s8,s8,s32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((16, 8, 16)),
access_per_vector=0,
is_nvrtc=True,
int8_inference=True,
dynamic_mask=False),
]
def __init__(self, in_features, out_features, bias_=True,
dtype=torch.qint8):
super().__init__()
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.in_features = in_features
self.out_features = out_features
bias = None
if bias_:
bias = torch.zeros(out_features, dtype=torch.float)
if dtype == torch.qint8:
qweight = torch._empty_affine_quantized(
[out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
elif dtype == torch.float16:
qweight = torch.zeros([out_features, in_features], dtype=torch.float)
else:
raise RuntimeError('Unsupported dtype specified for quantized Linear!')
self._weight: torch.Tensor = qweight
self._bias: Optional[torch.Tensor] = bias
self.scale = 1.0
self.zero_point = 0
self._nvrtc_params = None
# this standard int8 conv operators is used for only quantization debug (to implement quantized Linear/Conv for qcuda backend)
def _get_name(self):
return 'QuantizedLinearPerChannelWeight'
def extra_repr(self):
return 'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'.format(
self.in_features, self.out_features, self.scale, self.zero_point, self.weight().qscheme()
)
@staticmethod
def _linear_fwd(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], scale: float, act: tv.gemm.Activation, nvrtc_params):
is_ref = True
inp_scale = x.q_scale()
w_scales = weight.q_per_channel_scales().to(torch.float32)
out_scale = scale
channel_scale = (inp_scale * w_scales) / out_scale
channel_k = weight.size(0)
channel_c = weight.size(-1)
if bias is not None:
bias = bias / out_scale
else:
bias = torch.zeros([channel_k], dtype=torch.float32, device=x.device)
ldi = x.size(-1)
ldw = weight.size(-1)
ldo = weight.size(0)
params = ConvParams(2, ConvOpTypeCpp(ConvOpType.kForward.value))
assert len(LinearPerChannelWeight.CUMM_CONV_PARAMS) == 2
algo_desp_fast = algocore.get_conv_algo_desp_from_param(LinearPerChannelWeight.CUMM_CONV_PARAMS[0])
algo_desp_generic = algocore.get_conv_algo_desp_from_param(LinearPerChannelWeight.CUMM_CONV_PARAMS[1])
algo_desp = algo_desp_fast
if not algo_desp_fast.supported_ldx_conv(ldi, ldw, ldo):
algo_desp = algo_desp_generic
# if not algo_desp.supported_ldx_conv(ldi, ldw, ldo):
# breakpoint()
if is_ref:
x_detach = torch.zeros(size=x.size(), dtype=torch.int8, device=x.device)
weight_detach = torch.zeros(size=weight.size(), dtype=torch.int8, device=x.device)
torch_tensor_to_tv(x_detach).copy_(torch_tensor_to_tv(x))
torch_tensor_to_tv(weight_detach).copy_(torch_tensor_to_tv(weight))
# o_tmp = torch.from_numpy(x_detach.to(torch.int32).cpu().numpy() @ weight_detach.to(torch.int32).cpu().numpy().T).to(x.device)
o_tmp = x_detach.to(torch.float32) @ weight_detach.to(torch.float32).T
o_tmp = o_tmp.to(torch.float32) * channel_scale + bias
if act == tv.gemm.Activation.ReLU:
o_tmp = torch.maximum(o_tmp, torch.tensor(0, dtype=o_tmp.dtype, device=x.device))
o_tmp = torch.clip(torch.round(o_tmp), -128, 127).to(torch.int8)
output = torch._empty_affine_quantized(o_tmp.shape, scale=scale, zero_point=0, dtype=x.dtype, device=x.device)
torch_tensor_to_tv(output).copy_(torch_tensor_to_tv(o_tmp))
return output, None
else:
assert algo_desp.supported_ldx_conv(ldi, ldw, ldo)
out_shape = [x.size(0),weight.size(0) ]
output = torch._empty_affine_quantized(out_shape, scale=scale, zero_point=0, dtype=x.dtype, device=x.device)
params.conv_algo_desp = algo_desp
params.input = torch_tensor_to_tv(x).view([x.size(0), 1, 1, channel_c])
params.verbose = False
params.weight = torch_tensor_to_tv(weight).view([channel_k, 1, 1, channel_c])
params.output = torch_tensor_to_tv(output).view([x.size(0), 1, 1, channel_k])
params.split_k_slices = 1
params.alpha = 1.0
params.beta = 0.0
params.act_alpha = 1.0
params.act_beta = 0.0
params.act_type = act
params.padding = [0, 0]
params.stride = [1, 1]
params.dilation = [1, 1]
params.stream = get_current_stream()
if nvrtc_params is None:
mod, ker = SimpleConv._compile_nvrtc_module(algo_desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "conv_kernel")
params.bias = torch_tensor_to_tv(bias)
params.scale = torch_tensor_to_tv(channel_scale)
params.nvrtc_params = nvrtc_params
tv.gemm.run_nvrtc_conv_kernel(params)
return output, nvrtc_params
def forward(self, x: torch.Tensor) -> torch.Tensor:
out, nvrtc_params = self._linear_fwd(x, self.weight(), self.bias(), self.scale, tv.gemm.Activation.None_, self._nvrtc_params)
if self._nvrtc_params is None:
self._nvrtc_params = nvrtc_params
return out
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
# regular QTensor form for serialization. Packed weights should not live
# outside the process in which they were created, rather they should be derived
# from the QTensor weight.
#
# Version 1
# self
# |--- scale : float
# |--- zero_point : int
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 3
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- _packed_params : (Tensor, Tensor) representing weight, bias
# of LinearPackedParams C++ struct
#
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'scale'] = torch.tensor(self.scale)
destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
(w, b) = self._weight_bias()
destination[prefix + 'weight'] = w
destination[prefix + 'bias'] = b
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self.scale = float(state_dict[prefix + 'scale'])
state_dict.pop(prefix + 'scale')
self.zero_point = int(state_dict[prefix + 'zero_point'])
state_dict.pop(prefix + 'zero_point')
version = local_metadata.get('version', None)
# if version is None or version == 1:
# # We moved the parameters into a LinearPackedParameters submodule
# weight = state_dict.pop(prefix + 'weight')
# bias = state_dict.pop(prefix + 'bias')
# state_dict.update({prefix + '_packed_params.weight': weight,
# prefix + '_packed_params.bias': bias})
super()._load_from_state_dict(
state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
# Function rather than property to make sure that JIT serialization doesn't
# register this as an attribute
def _weight_bias(self):
return (self._weight, self._bias)
def weight(self):
return self._weight_bias()[0]
def bias(self):
return self._weight_bias()[1]
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
self._weight = w
self._bias = b
# self._packed_params.set_weight_bias(w, b)
@classmethod
def from_float(cls, mod):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
if hasattr(mod, 'weight_fake_quant'):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:
mod.weight, mod.bias = fuse_linear_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
weight_post_process = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
# This function does not participate in JIT, so it is OK to ignore
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore[assignment]
supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore[attr-defined]
error_msg = 'nnq.{}.from_float only works for {}, but got: {}'.format(cls.__name__, supported_modules, type(mod))
assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore[attr-defined]
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear = cls(mod.in_features,
mod.out_features,
dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
qlinear.scale = float(act_scale)
qlinear.zero_point = int(act_zp)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
output_zero_point (int): zero point for output Tensor
"""
qlinear = cls(
ref_qlinear.in_features,
ref_qlinear.out_features)
qweight = ref_qlinear.get_quantized_weight()
qlinear.set_weight_bias(qweight, ref_qlinear.bias)
qlinear.scale = float(output_scale)
qlinear.zero_point = int(output_zero_point)
return qlinear
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment