Commit e387ee74 authored by yan.yan's avatar yan.yan
Browse files

sync quantization code

parent b1c57a31
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv_relu import *
\ No newline at end of file
from .conv_relu import *
......@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from spconv.pytorch.core import SparseConvTensor
import spconv.pytorch.quantization.quantized as nnq
from spconv.pytorch.quantization.intrinsic import SpconvReLUNd
from spconv.pytorch.quantization.intrinsic import SpconvReLUNd, SpconvAddReLUNd
from cumm import tensorview as tv
from spconv.pytorch.quantization.utils import fuse_spconv_bn_weights
import spconv.pytorch.quantization.intrinsic.qat as snniqat
import spconv.pytorch.quantization.intrinsic as snni
import torch
__all__ = ["SparseConvReLU"]
__all__ = ["SparseConvReLU", "SparseConvAddReLU"]
class SparseConvReLU(nnq.SparseConv):
r"""
......@@ -36,16 +39,18 @@ class SparseConvReLU(nnq.SparseConv):
def forward(self, input):
inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales()
w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale
channel_scale = out_scale / (inp_scale * w_scales)
bias = self.bias() * out_scale
return self._conv_forward(False, input,
self.weight(), bias, channel_scale=channel_scale, output_scale=out_scale,
channel_scale = (inp_scale * w_scales) / out_scale
scaled_bias = self.bias() / out_scale
# print(bias.dtype, input.features.dtype, channel_scale.dtype, w_scales.dtype)
res = self._conv_forward(False, input,
self.weight(), scaled_bias, channel_scale=channel_scale, output_scale=out_scale,
act_type=tv.gemm.Activation.ReLU)
return res
def _get_name(self):
return 'QuantizedConvReLU1d'
return 'QuantizedSparseConvReLU'
@classmethod
def from_float(cls, mod):
......@@ -60,3 +65,45 @@ class SparseConvReLU(nnq.SparseConv):
assert type(ref_qconv) != snni.SpconvBnReLUNd, \
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class SparseConvAddReLU(nnq.SparseConv):
r"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE = SpconvAddReLUNd # type: ignore[assignment]
def forward(self, input, add_input: Optional[SparseConvTensor] = None):
inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale
channel_scale = (inp_scale * w_scales) / out_scale
scaled_bias = self.bias() / out_scale
# print(bias.dtype, input.features.dtype, channel_scale.dtype, w_scales.dtype)
res = self._conv_forward(False, input,
self.weight(), scaled_bias, channel_scale=channel_scale, output_scale=out_scale,
act_type=tv.gemm.Activation.ReLU, add_input=add_input)
return res
def _get_name(self):
return 'QuantizedSparseConvReLU'
@classmethod
def from_float(cls, mod):
if type(mod) == snniqat.SparseConvBnAddReLU:
mod.weight, mod.bias = fuse_spconv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
return super(SparseConvAddReLU, cls).from_float(mod)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert type(ref_qconv) != snni.SpconvBnReLUNd, \
"BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
......@@ -323,9 +323,12 @@ class SparseConv(_SparseConv):
return 'QuantizedSparseConvolution'
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
assert b is not None
self._weight = w
self._bias = b
if b is None:
# currently bias tensor must exists.
self._bias = torch.zeros((w.shape[0],), dtype=torch.float32, device=w.device)
else:
self._bias = b
def weight(self):
return self._weight_bias()[0]
......@@ -336,12 +339,11 @@ class SparseConv(_SparseConv):
def forward(self, input: SparseConvTensor):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
print("?")
inp_scale = input.q_scale()
w_scales = self.weight().q_per_channel_scales()
w_scales = self.weight().q_per_channel_scales().to(torch.float32)
out_scale = self.scale
channel_scale = out_scale / (inp_scale * w_scales)
bias = self.bias() * out_scale
channel_scale = (inp_scale * w_scales) / out_scale
bias = self.bias() / out_scale
return self._conv_forward(False, input,
self.weight(), bias, channel_scale=channel_scale, output_scale=out_scale)
return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
......
......@@ -132,7 +132,7 @@ class SpConv(_SpConvNd, sconvmod.SparseConvolution):
device=device)
self._init_weight_qparams(weight_qparams, device)
def forward(self, x: SparseConvTensor) -> SparseConvTensor:
def forward(self, x: SparseConvTensor, add_input: Optional[SparseConvTensor] = None) -> SparseConvTensor:
"""
we have:
w(float) -- quant - dequant \
......@@ -144,7 +144,7 @@ class SpConv(_SpConvNd, sconvmod.SparseConvolution):
and the backend should be able to fuse the ops with `*` into a quantized SparseConvolution
"""
weight_quant_dequant = self.get_weight()
result = self._conv_forward(self.training, x, weight_quant_dequant, self.bias)
result = self._conv_forward(self.training, x, weight_quant_dequant, self.bias, add_input=add_input)
return result
def _get_name(self):
......
......@@ -50,3 +50,4 @@ def fuse_spconv_act_eval(conv, act):
else:
raise NotImplementedError
return fused_conv
from collections import OrderedDict
import contextlib
import operator
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.ao.quantization.fx.match_utils import (
MatchAllNode,
)
from torch.ao.quantization.quantize_fx import (
fuse_fx,
)
from torch.ao.quantization.backend_config import (
get_qnnpack_backend_config,
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType,
get_fbgemm_backend_config
)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as qfx
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(3)
self.iden = nn.Identity()
def forward(self, x):
y = x
y = self.iden(x)
x = self.conv(x)
x = self.bn(x)
x = torch.add(x, y)
x = self.relu(x)
return x
class M2(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
self.bn1 = torch.nn.BatchNorm2d(3)
self.bn2 = torch.nn.BatchNorm2d(3)
self.relu1 = torch.nn.ReLU()
self.relu2 = torch.nn.ReLU()
self.iden = nn.Identity()
def forward(self, x):
y = x
y = self.iden(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = torch.add(x, y)
x = self.relu2(x)
return x
class M3(torch.nn.Module):
def __init__(self):
super().__init__()
conv1 = torch.nn.Conv2d(3, 3, 3)
conv2 = torch.nn.Conv2d(3, 3, 3)
bn1 = torch.nn.BatchNorm2d(3)
bn2 = torch.nn.BatchNorm2d(3)
relu1 = torch.nn.ReLU()
self.relu2 = torch.nn.ReLU()
self.conv1_bn_relu = nn.Sequential(OrderedDict(conv=conv1, bn=bn1, relu=nn.ReLU(inplace=True)))
self.conv2_bn = nn.Sequential(OrderedDict(conv=conv2, bn=bn2))
self.iden = nn.Identity()
self.conv3 = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
y = x
y = self.conv3(x)
x = self.conv1_bn_relu(x)
x = self.conv2_bn(x)
x = self.relu2(torch.add(x,y))
return x
m = M().eval()
def fuse_conv_bn_relu(is_qat, relu, add_pattern):
_, bn_pattern, _ = add_pattern
bn, conv = bn_pattern
return conv
def conv_bn_res_relu_root_node_getter(pattern):
relu, add_pattern = pattern
_, bn_pattern, _ = add_pattern
bn, conv = bn_pattern
return conv
def conv_bn_res_relu_extra_inputs_getter(pattern):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu, add_pattern = pattern
_, bn_pattern, extra_input = add_pattern
bn, conv = bn_pattern
return [extra_input]
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
# .set_fuser_method(fuse_conv_bn_relu) \
# ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
# ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
fbgemm_weighted_op_int8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)
conv_bn_res_relu_config = BackendPatternConfig() \
._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
.set_fuser_method(fuse_conv_bn_relu) \
._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) \
.set_dtype_configs(fbgemm_weighted_op_int8_dtype_config)
backend_config = get_fbgemm_backend_config().set_backend_pattern_config(conv_bn_res_relu_config)
# m = fuse_fx(m, backend_config=backend_config)
qmapping = get_default_qconfig_mapping()
prepared_model = qfx.prepare_fx(m, qmapping, (), backend_config=backend_config)
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qmapping, backend_config=backend_config)
converted_model.print_readable()
\ No newline at end of file
from collections import OrderedDict
import contextlib
import operator
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.ao.quantization.fx.match_utils import (
MatchAllNode,
)
from torch.ao.quantization.quantize_fx import (
fuse_fx,
)
from torch.ao.quantization.backend_config import (
get_qnnpack_backend_config,
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType,
get_fbgemm_backend_config
)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as qfx
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(3)
self.iden = nn.Conv2d(3, 3, 3)
# self.iden2 = nn.Conv2d(3, 3, 3)
def forward(self, x):
y = x
y = self.iden(x)
x = self.conv(x)
x = self.bn(x)
x = torch.add(x, y)
x = self.relu(x)
return x
m = M().eval()
def fuse_conv_bn_relu(is_qat, relu, add_pattern):
_, bn_pattern, _ = add_pattern
bn, conv = bn_pattern
return conv
def conv_bn_res_relu_root_node_getter(pattern):
relu, add_pattern = pattern
_, bn_pattern, _ = add_pattern
bn, conv = bn_pattern
return conv
def conv_bn_res_relu_extra_inputs_getter(pattern):
""" get inputs pattern for extra inputs, inputs for root node
are assumed to be copied over from root node to the fused node
"""
relu, add_pattern = pattern
_, bn_pattern, extra_input = add_pattern
bn, conv = bn_pattern
return [extra_input]
# for pytorch <= 1.13
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
# .set_fuser_method(fuse_conv_bn_relu) \
# ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
# ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
fbgemm_weighted_op_int8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)
# for pytorch master
conv_bn_res_relu_config = BackendPatternConfig() \
._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
.set_fuser_method(fuse_conv_bn_relu) \
._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) \
.set_dtype_configs(fbgemm_weighted_op_int8_dtype_config)
backend_config = get_fbgemm_backend_config()# .set_backend_pattern_config(conv_bn_res_relu_config)
# m = fuse_fx(m, backend_config=backend_config)
qmapping = get_default_qconfig_mapping()
prepared_model = qfx.prepare_fx(m, qmapping, (), backend_config=backend_config)
prepared_model.print_readable()
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qmapping, backend_config=backend_config)
converted_model.print_readable()
\ No newline at end of file
......@@ -88,7 +88,7 @@ class SparseConvTester:
op = expand_nd(ndim, 0)
self.kv: int = np.prod(self.ksize)
self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
self.output_scale: float = 1.0
self.output_scale: float = 3.4
self.check_int8_infer = check_int8_infer
if check_int8_infer:
assert check_bias and self.dtype == np.int8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment