Commit aa26c99e authored by yan.yan's avatar yan.yan
Browse files

working on quantization

parent ee8c9465
# torch.ao.nn.intrinsic.qat.modules.conv_fused
import math
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils import fuse_conv_bn_weights
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter
from typing import TypeVar
from spconv.pytorch.conv import SparseConvolution
from typing import List, Optional, Tuple, Union
from spconv.core import ConvAlgo
from cumm import tensorview as tv
from spconv.pytorch.core import SparseConvTensor
import spconv.pytorch.quantization.intrinsic as snni
# Type variable bounded by SparseConvolution; used below as the placeholder value of
# `_SparseConvBn._FLOAT_MODULE`, which concrete subclasses override.
MOD = TypeVar('MOD', bound=SparseConvolution)
class _SparseConvBn(SparseConvolution, nni._FusedModule):
    """Base QAT module that fuses a sparse convolution with a ``BatchNorm1d``.

    The convolution weight is scaled by ``bn.weight / running_std`` and passed
    through ``weight_fake_quant`` before the convolution is evaluated; the
    scaling is then undone on the output features and the regular batch norm
    is applied (see ``_forward_approximate``).

    Subclasses are expected to define ``_FLOAT_MODULE``, ``_FLOAT_BN_MODULE``,
    ``_FLOAT_RELU_MODULE`` and ``_FUSED_FLOAT_MODULE``, which ``from_float`` and
    ``to_float`` read.
    """

    _version = 2
    _FLOAT_MODULE = MOD
    _FLOAT_CONV_MODULE = SparseConvolution

    def __init__(self,
                 # SparseConvolution args
                 ndim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 dilation: Union[int, List[int], Tuple[int, ...]] = 1,
                 groups: Union[int, List[int], Tuple[int, ...]] = 1,
                 bias: bool = True,
                 subm: bool = False,
                 output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 transposed: bool = False,
                 inverse: bool = False,
                 indice_key: Optional[str] = None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 record_voxel_count: bool = False,
                 act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
                 act_alpha: float = 0,
                 act_beta: float = 0,
                 name=None,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        """Build the fused conv + bn QAT module.

        The ``SparseConvolution`` arguments are forwarded unchanged, except
        that the underlying convolution is always created with ``bias=False``;
        the bias requested by ``bias`` is owned by this module and added
        explicitly in the forward pass.

        Args:
            eps: ``eps`` of the internal ``BatchNorm1d``.
            momentum: ``momentum`` of the internal ``BatchNorm1d``.
            freeze_bn: if True, batch-norm statistics are frozen right away.
            qconfig: QAT configuration; ``qconfig.weight()`` creates the weight
                fake-quant module. Required.
        """
        SparseConvolution.__init__(self, ndim, in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                   bias=False,
                                   subm=subm,
                                   output_padding=output_padding,
                                   transposed=transposed,
                                   inverse=inverse,
                                   indice_key=indice_key,
                                   algo=algo,
                                   fp32_accum=fp32_accum,
                                   record_voxel_count=record_voxel_count,
                                   act_type=act_type,
                                   act_alpha=act_alpha,
                                   act_beta=act_beta,
                                   name=name)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        """Reset the running statistics of the internal batch norm."""
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        """Re-initialize the batch-norm state and, if present, the conv bias."""
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        """Delegate parameter initialization to the parent class."""
        super(_SparseConvBn, self).reset_parameters()

    def update_bn_stats(self):
        """Unfreeze batch norm: it is put in training mode. Returns ``self``."""
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        """Freeze batch norm: it is put in eval mode. Returns ``self``."""
        self.freeze_bn = True
        self.bn.training = False
        return self

    @staticmethod
    def _channel_broadcast_shape(features: torch.Tensor) -> List[int]:
        """Return a shape that broadcasts a per-channel vector over ``features``.

        The channel axis is taken to be dim 1 of ``features`` (the axis
        ``BatchNorm1d`` normalizes), so the result is ``[1, -1, 1, ...]`` with
        the same rank as ``features``.
        """
        shape = [1] * features.dim()
        shape[1] = -1
        return shape

    def _forward(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
        """Run the fused forward pass; only the approximate path is enabled."""
        # The slow path is not implemented yet, so it must stay disabled.
        assert not self._enable_slow_path_for_better_numerical_stability
        if self._enable_slow_path_for_better_numerical_stability:
            return self._forward_slow(input)
        return self._forward_approximate(input, add_input)

    def _forward_approximate(self, input: SparseConvTensor, add_input: Optional[SparseConvTensor] = None):
        """Approximated method to fuse conv and bn. It requires only one forward pass.
        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std

        Args:
            input: sparse tensor fed to the convolution.
            add_input: optional sparse tensor whose features are added to the
                batch-norm output (residual add).

        Returns:
            The convolution output sparse tensor with its features replaced by
            the batch-normed (and optionally summed) features.
        """
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # output channels live on dim 0 of the weight
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.features.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.features.dtype)
        conv_spt = self._conv_forward(input, scaled_weight, zero_bias)
        conv = conv_spt.features
        # The per-channel vectors must broadcast against the output *features*
        # (which are fed to BatchNorm1d), not against the higher-rank weight.
        bias_shape = self._channel_broadcast_shape(conv)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        if add_input is not None:
            conv = conv + add_input.features
        conv_spt = conv_spt.replace_feature(conv)
        return conv_spt

    def _forward_slow(self, input: SparseConvTensor):
        """
        TODO not implemented for now
        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
        It requires two forward passes but handles the case bn.weight == 0

        Conv: Y = WX + B_c
        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c

        Batch statistics:
          mean_Y = Y.mean()
                 = Y0.mean() + B_c
          var_Y = (Y - mean_Y)^2.mean()
                = (Y0 - Y0.mean())^2.mean()

        BN (r: bn.weight, beta: bn.bias):
          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta

        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta

        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta

        QAT with fused conv bn:
          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
        """
        assert self.bn.running_var is not None
        assert self.bn.running_mean is not None

        # using zero bias here since the bias for original conv
        # will be added later
        zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.features.dtype)

        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1

        conv_out = torch.Tensor()
        if self.bn.training:
            # needed to compute batch mean/std
            conv_spt = self._conv_forward(input, self.weight, zero_bias)
            conv_out = conv_spt.features
            # update bn statistics
            with torch.no_grad():
                conv_out_bias = (
                    conv_out if self.bias is None
                    else conv_out + self.bias.reshape(self._channel_broadcast_shape(conv_out))
                )
                self.bn(conv_out_bias)

        # fused conv + bn without bias using bn running statistics
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # fused conv without bias for inference: (r * W / running_std) * X
        conv_bn_spt = self._conv_forward(input, scaled_weight, zero_bias)
        conv_bn = conv_bn_spt.features
        # per-channel vectors broadcast against the feature tensor, not the weight
        bias_shape = self._channel_broadcast_shape(conv_bn)

        if self.bn.training:
            # reduce over every feature axis except the channel axis (dim 1)
            avg_dims = [0] + list(range(2, conv_out.dim()))
            batch_mean = conv_out.mean(avg_dims)
            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
                avg_dims
            )
            batch_std = torch.sqrt(batch_var + self.bn.eps)

            # scale to use batch std in training mode
            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
            unscale_factor = running_std / batch_std
            conv_bn *= unscale_factor.reshape(bias_shape)

            fused_mean = batch_mean
            fused_std = batch_std
        else:
            fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
            fused_std = running_std

        # fused bias = beta - r * mean / std
        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
        conv_bn += fused_bias.reshape(bias_shape)

        # HACK to let conv bias participate in loss to avoid DDP error (parameters
        # were not used in producing loss)
        if self.bias is not None:
            conv_bn += (self.bias - self.bias).reshape(bias_shape)
        conv_bn_spt = conv_bn_spt.replace_feature(conv_bn)
        return conv_bn_spt

    def extra_repr(self):
        """Return the parent class's extra representation string."""
        # TODO(jerryzh): extend
        return super(_SparseConvBn, self).extra_repr()

    def forward(self, input):
        """Apply the fused conv + bn to ``input`` (no residual input)."""
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load state, renaming version-1 batch-norm keys to the version-2 layout first."""
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_SparseConvBn, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.ao.quantization utilities
            or directly from user

            `mod` is indexed as ``mod[0]`` (the sparse conv) and ``mod[1]``
            (the batch norm); the returned module shares their parameters and
            buffers.
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv: SparseConvolution = mod[0]
        bn: nn.BatchNorm1d = mod[1]
        qat_convbn = cls(conv.ndim, conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups,
                         conv.bias is not None,
                         subm=conv.subm,
                         output_padding=conv.output_padding,
                         transposed=conv.transposed,
                         inverse=conv.inverse,
                         indice_key=conv.indice_key,
                         algo=conv.algo,
                         fp32_accum=conv.fp32_accum,
                         record_voxel_count=conv.record_voxel_count,
                         act_type=conv.act_type,
                         act_alpha=conv.act_alpha,
                         act_beta=conv.act_beta,
                         name=conv.name,
                         eps=bn.eps, momentum=bn.momentum,
                         freeze_bn=False,
                         qconfig=qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        """Convert back to a float module.

        A new ``_FLOAT_CONV_MODULE`` is built from this module's settings and
        detached weights. When the class defines ``_FLOAT_BN_MODULE`` the batch
        norm is folded into the conv weights; when it defines
        ``_FLOAT_RELU_MODULE`` the conv is wrapped with a ReLU in
        ``_FUSED_FLOAT_MODULE``. The result keeps this module's training flag.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.ndim,
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            subm=self.subm,
            output_padding=self.output_padding,
            transposed=self.transposed,
            inverse=self.inverse,
            indice_key=self.indice_key,
            algo=self.algo,
            fp32_accum=self.fp32_accum,
            record_voxel_count=self.record_voxel_count,
            act_type=self.act_type,
            act_alpha=self.act_alpha,
            act_beta=self.act_beta,
            name=self.name)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
class SparseConvBn(_SparseConvBn):
    r"""
    QAT module fused from a :class:`SparseConvolution` and a
    :class:`torch.nn.BatchNorm1d`, with a FakeQuantize module attached to the
    convolution weight.

    The constructor combines the arguments of the sparse convolution and of
    the batch norm (see :class:`_SparseConvBn`).

    Attributes:
        freeze_bn: whether the batch-norm statistics are frozen
        weight_fake_quant: fake quant module for weight
    """
    # float module type accepted by `from_float`
    _FLOAT_MODULE = snni.SpconvBnNd  # type: ignore[assignment]
    # building blocks used by `to_float`
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = None
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = snni.SpconvReLUNd
class SparseConvBnReLU(_SparseConvBn):
    r"""
    QAT module fused from a :class:`SparseConvolution`, a
    :class:`torch.nn.BatchNorm1d` and a :class:`torch.nn.ReLU`, with a
    FakeQuantize module attached to the convolution weight.

    The constructor combines the arguments of the sparse convolution and of
    the batch norm (see :class:`_SparseConvBn`).

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    # float module type accepted by `from_float`
    _FLOAT_MODULE = snni.SpconvBnReLUNd  # type: ignore[assignment]
    # building blocks used by `to_float`
    _FLOAT_CONV_MODULE = SparseConvolution
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = snni.SpconvReLUNd

    def forward(self, input):
        """Run the fused conv + bn, then apply ReLU to the output features."""
        fused = _SparseConvBn._forward(self, input)
        activated = F.relu(fused.features)
        return fused.replace_feature(activated)

    @classmethod
    def from_float(cls, mod):
        """Create the QAT module from a float module via the base implementation."""
        return super().from_float(mod)
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize, fused_wt_fake_quant_range_neg_127_to_127
from spconv.pytorch.core import SparseConvTensor
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
class SparseFusedMovingAvgObsFakeQuantize(FusedMovingAvgObsFakeQuantize):
    """``FusedMovingAvgObsFakeQuantize`` that operates on a ``SparseConvTensor``."""

    def forward(self, input: SparseConvTensor):
        """Fake-quantize ``input.features`` and return ``input`` with them replaced."""
        # the parent works on plain tensors, so only the feature tensor is passed through
        quantized_features = super().forward(input.features)
        return input.replace_feature(quantized_features)
# Default QAT qconfig for spconv modules.
# Activations: sparse-aware fused fake-quant with a moving-average min/max observer,
# qint8 in [-128, 127], no reduced range.
# Weights: torch's stock `fused_wt_fake_quant_range_neg_127_to_127` (by its name, restricted
# to [-127, 127] — confirm against the torch.ao.quantization documentation).
default_symmetric_spconv_qat_qconfig = QConfig(
activation=SparseFusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
reduce_range=False,
eps=2 ** -12),
weight=fused_wt_fake_quant_range_neg_127_to_127)
from typing import Union, Callable, Tuple, Dict, Optional, Type, Any
import torch.nn as nn
import spconv.pytorch as spconv
from .utils import fuse_spconv_bn_eval
from . import intrinsic as snni
from .conv_fused import SparseConvBn, SparseConvBnReLU
def fuse_conv_bn(conv, bn):
    r"""Fuse a sparse convolution with the batch norm that follows it.

    In training mode the two modules are wrapped in the matching intrinsic
    container (``snni.SpconvBnNd``); in eval mode the batch norm is folded
    into a copy of the convolution via ``fuse_spconv_bn_eval``.

    Args:
        conv: sparse convolution module instance (one of the spconv
            SubM/Sparse/SparseInverse 1d-3d convolution classes)
        bn: batch norm instance that needs to be fused with the conv

    Raises:
        NotImplementedError: in training mode, when ``type(conv)`` is not one
            of the supported convolution classes.
    """
    assert conv.training == bn.training, \
        "Conv and BN both must be in the same mode (train or eval)."

    supported_conv_types = (
        spconv.SubMConv1d,
        spconv.SparseConv1d,
        spconv.SparseInverseConv1d,
        spconv.SubMConv2d,
        spconv.SparseConv2d,
        spconv.SparseInverseConv2d,
        spconv.SubMConv3d,
        spconv.SparseConv3d,
        spconv.SparseInverseConv3d,
    )
    # every supported conv type maps to the same conv+bn container
    container_by_conv_type = dict.fromkeys(supported_conv_types, snni.SpconvBnNd)

    if not conv.training:
        return fuse_spconv_bn_eval(conv, bn)

    assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
    assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
    assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'

    container_cls = container_by_conv_type.get(type(conv), None)
    if container_cls is None:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    return container_cls(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
    r"""Fuse a sparse convolution with the batch norm and ReLU that follow it.

    In training mode the three modules are wrapped in ``snni.SpconvBnReLUNd``.
    In eval mode the batch norm is folded into a copy of the convolution via
    ``fuse_spconv_bn_eval`` and the result is wrapped with the ReLU in
    ``snni.SpconvReLUNd``.

    Args:
        conv: sparse convolution module instance (one of the spconv
            SubM/Sparse/SparseInverse 1d-3d convolution classes)
        bn: batch norm instance that needs to be fused with the conv
        relu: ReLU instance applied after the batch norm

    Raises:
        NotImplementedError: when ``type(conv)`` is not one of the supported
            convolution classes.
    """
    assert conv.training == bn.training == relu.training, \
        "Conv and BN both must be in the same mode (train or eval)."

    supported_conv_types = (
        spconv.SubMConv1d,
        spconv.SparseConv1d,
        spconv.SparseInverseConv1d,
        spconv.SubMConv2d,
        spconv.SparseConv2d,
        spconv.SparseInverseConv2d,
        spconv.SubMConv3d,
        spconv.SparseConv3d,
        spconv.SparseInverseConv3d,
    )

    if conv.training:
        train_containers = dict.fromkeys(supported_conv_types, snni.SpconvBnReLUNd)
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        container_cls = train_containers.get(type(conv), None)
        if container_cls is None:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
        return container_cls(conv, bn, relu)

    eval_containers = dict.fromkeys(supported_conv_types, snni.SpconvReLUNd)
    container_cls = eval_containers.get(type(conv), None)
    if container_cls is None:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
    # fold bn into the conv first, then pair the folded conv with the relu
    folded_conv = fuse_spconv_bn_eval(conv, bn)
    return container_cls(folded_conv, relu)
# Maps a module-type pattern to the function that fuses it. Every supported sparse
# convolution class gets two entries, in this order: (conv, bn) and (conv, bn, relu).
DEFAULT_SPCONV_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
    pattern: fuser
    for conv_cls in (
        spconv.SubMConv1d,
        spconv.SparseConv1d,
        spconv.SparseInverseConv1d,
        spconv.SubMConv2d,
        spconv.SparseConv2d,
        spconv.SparseInverseConv2d,
        spconv.SubMConv3d,
        spconv.SparseConv3d,
        spconv.SparseInverseConv3d,
    )
    for pattern, fuser in (
        ((conv_cls, nn.BatchNorm1d), fuse_conv_bn),
        ((conv_cls, nn.BatchNorm1d, nn.ReLU), fuse_conv_bn_relu),
    )
}

# Default map for swapping float (intrinsic container) modules to their QAT modules
DEFAULT_SPCONV_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    # nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    snni.SpconvBnNd: SparseConvBn,
    snni.SpconvBnReLUNd: SparseConvBnReLU,
}
import torch
from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d
from torch.nn.utils.parametrize import type_before_parametrizations
import torch.ao.nn.intrinsic as nni
from spconv.pytorch.conv import SparseConvolution
class SpconvReLUNd(nni._FusedModule):
    r"""Sequential container holding a sparse convolution followed by a ReLU.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, relu):
        conv_ok = isinstance(conv, SparseConvolution)
        assert conv_ok and isinstance(relu, ReLU), \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)
class SpconvBnNd(nni._FusedModule):
    r"""Sequential container holding a sparse convolution followed by a BatchNorm1d.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn):
        conv_ok = isinstance(conv, SparseConvolution)
        assert conv_ok and isinstance(bn, BatchNorm1d), \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(bn))
        super().__init__(conv, bn)
class SpconvBnReLUNd(nni._FusedModule):
    r"""Sequential container holding a sparse convolution, a BatchNorm1d and a ReLU.
    During quantization this will be replaced with the corresponding fused module."""

    def __init__(self, conv, bn, relu):
        conv_ok = isinstance(conv, SparseConvolution)
        assert conv_ok and isinstance(bn, BatchNorm1d) and \
            isinstance(relu, ReLU), 'Incorrect types for input modules{}{}{}' \
            .format(type(conv), type(bn), type(relu))
        super().__init__(conv, bn, relu)
from spconv.pytorch.modules import SparseModule
from spconv.pytorch.conv import SparseConvolution
from spconv.pytorch.core import SparseConvTensor
import torch
class ConvBatchNormAddAct(torch.nn.Module):
    """Conv -> batch norm -> residual add -> activation in one module.

    Intended for simple int8 residual op fusion: the add is handled inside
    this module so the whole pattern can be treated as one unit.
    """

    def __init__(self, conv: SparseConvolution, bn: torch.nn.BatchNorm1d, act: torch.nn.ReLU) -> None:
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act

    def forward(self, x: SparseConvTensor, x_add: SparseConvTensor):
        """Return ``act(bn(conv(x)) + x_add)`` computed on the feature tensors."""
        conv_out = self.conv(x)
        normed = conv_out.replace_feature(self.bn(conv_out.features))
        # residual add on features, then the activation receives the sparse tensor
        summed = normed.replace_feature(normed.features + x_add.features)
        return self.act(summed)
import torch
import copy
from cumm import tensorview as tv
def fuse_spconv_bn_weights(conv_w_OKI, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold batch-norm statistics into sparse-conv weight and bias.

    Args:
        conv_w_OKI: conv weight laid out as (out, *kernel, in).
        conv_b: conv bias, or None (treated as zeros).
        bn_rm: batch-norm running mean.
        bn_rv: batch-norm running variance.
        bn_eps: batch-norm epsilon.
        bn_w: batch-norm weight, or None (treated as ones).
        bn_b: batch-norm bias, or None (treated as zeros).

    Returns:
        Tuple ``(weight, bias)`` of new ``torch.nn.Parameter`` objects; the
        weight is contiguous and keeps the (out, *kernel, in) layout.
    """
    spatial_rank = conv_w_OKI.ndim - 2
    # (out, *kernel, in) -> (out, in, *kernel)
    to_oik = [0, spatial_rank + 1, *range(1, spatial_rank + 1)]
    weight_oik = conv_w_OKI.permute(*to_oik)

    conv_b = torch.zeros_like(bn_rm) if conv_b is None else conv_b
    bn_w = torch.ones_like(bn_rm) if bn_w is None else bn_w
    bn_b = torch.zeros_like(bn_rm) if bn_b is None else bn_b

    inv_std = torch.rsqrt(bn_rv + bn_eps)
    # per-output-channel scale, broadcast over all remaining weight axes
    channel_scale = (bn_w * inv_std).reshape([-1] + [1] * (len(weight_oik.shape) - 1))
    weight_oik = weight_oik * channel_scale
    fused_bias = (conv_b - bn_rm) * inv_std * bn_w + bn_b

    # (out, in, *kernel) -> (out, *kernel, in)
    to_oki = [0, *range(2, spatial_rank + 2), 1]
    fused_weight = weight_oik.permute(*to_oki).contiguous()
    return torch.nn.Parameter(fused_weight), torch.nn.Parameter(fused_bias)
def fuse_spconv_bn_eval(conv, bn):
    """
    Fold batch norm `bn` into a copy of sparse conv `conv`.

    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode. Both modules must
    be in eval mode; `conv` itself is left untouched (a deep copy is modified).
    """
    assert not conv.training and not bn.training, "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    folded_weight, folded_bias = fuse_spconv_bn_weights(
        fused_conv.weight,
        fused_conv.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )
    fused_conv.weight = folded_weight
    fused_conv.bias = folded_bias
    return fused_conv
def fuse_spconv_act_eval(conv, act):
    """
    Fold activation `act` into a copy of sparse conv `conv`.

    The copy's `act_type` is set to the activation matching `act`
    (ReLU or LeakyReLU; for LeakyReLU `act_alpha` takes the negative slope).
    `conv` must be in eval mode and is left untouched.

    Raises:
        NotImplementedError: if `act` is neither ReLU nor LeakyReLU.
    """
    assert not conv.training, "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    if isinstance(act, torch.nn.ReLU):
        fused_conv.act_type = tv.gemm.Activation.ReLU
        return fused_conv
    if isinstance(act, torch.nn.LeakyReLU):
        fused_conv.act_type = tv.gemm.Activation.LeakyReLU
        fused_conv.act_alpha = act.negative_slope
        return fused_conv
    raise NotImplementedError
......@@ -53,7 +53,7 @@ class TestCase(unittest.TestCase):
print("not equal rhs = ", y)
np.testing.assert_array_equal(a, b)
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg: str = ""):
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
Args:
......@@ -68,22 +68,22 @@ class TestCase(unittest.TestCase):
"""
is_a_dict = isinstance(a, dict)
if is_a_dict != isinstance(b, dict):
raise ValueError("Can't compare dict to non-dict, %s vs %s." %
raise ValueError(f"Can't compare dict to non-dict, %s vs %s. {msg}" %
(a, b))
if is_a_dict:
self.assertCountEqual(a.keys(),
b.keys(),
msg="mismatched keys, expected %s, got %s" %
msg=f"mismatched keys, expected %s, got %s. {msg}" %
(a.keys(), b.keys()))
for k in a:
self._assertArrayLikeAllClose(a[k],
b[k],
rtol=rtol,
atol=atol,
msg="%s: expected %s, got %s." %
msg=f"%s: expected %s, got %s. {msg}" %
(k, a, b))
else:
self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)
self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
......
......@@ -37,6 +37,8 @@ from cumm import tensorview as tv
from spconv.constants import SPCONV_ALLOW_TF32
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
import os
from cumm.dtypes import get_npdtype_from_tvdtype
from cumm.gemm.codeops import div_up
from spconv.core import AlgoHint, ConvAlgo
from spconv.pytorch.conv import expand_nd
......@@ -63,14 +65,18 @@ NUMPY_DTYPE_TO_TORCH = {
}
class SparseConvTester:
def __init__(self, algo: ConvAlgo, subm: bool, shape: List[int], bs: int, dtype: np.dtype, N: int, K: int, C: int,
ksize: int, stride: int, padding: int, dilation: int, check_bias: bool = False, check_act: bool = False) -> None:
def __init__(self, algo: ConvAlgo, subm: bool, shape: List[int], bs: int, dtype: np.dtype, out_dtype: np.dtype, N: int, K: int, C: int,
ksize: int, stride: int, padding: int, dilation: int, check_bias: bool = False, check_act: bool = False,
check_int8_infer: bool = False, dtype_comp: np.dtype = np.dtype(np.float32)) -> None:
ndim = 3
transpose = False
self.shape = shape
self.bs = bs
self.dtype = dtype
self.out_dtype = out_dtype
self.dtype_th = NUMPY_DTYPE_TO_TORCH[dtype]
self.out_dtype_th = NUMPY_DTYPE_TO_TORCH[out_dtype]
self.K = K
self.C = C
self.ksize = expand_nd(ndim, ksize)
......@@ -82,6 +88,12 @@ class SparseConvTester:
op = expand_nd(ndim, 0)
self.kv: int = np.prod(self.ksize)
self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
self.output_scale: float = 1.0
self.check_int8_infer = check_int8_infer
if check_int8_infer:
assert check_bias and self.dtype == np.int8
self.dtype_comp = dtype_comp
if not subm:
if transpose:
out_shape = ops.get_deconv_output_size(shape, self.ksize, self.stride,
......@@ -91,6 +103,7 @@ class SparseConvTester:
self.padding, self.dilation)
else:
out_shape = shape
self.scales = np.random.uniform(0.5, 1.5, size=K).astype(dtype_comp)
sparse_dict = generate_sparse_data(shape, [N] * bs, C)
......@@ -109,7 +122,7 @@ class SparseConvTester:
self.pair_native = pair_ref
self.indice_num_per_loc = indice_num_per_loc
self.use_direct_table = True
self.mask_int_count = div_up(self.kv, 32)
self.out_shape = out_shape
if algo == ConvAlgo.Native:
self.out_inds: torch.Tensor = out_inds
......@@ -135,7 +148,6 @@ class SparseConvTester:
self.mask_argsort_fwd_splits = res[6]
self.mask_argsort_bwd_splits = res[7]
self.masks = res[8]
self.mask_int_count = res[9]
self.out_inds_scalar = Fsp._indice_to_scalar(self.out_inds.long(), [bs, *out_shape])
......@@ -159,18 +171,28 @@ class SparseConvTester:
self.check_act = check_act
self.subm = subm
self.output_add_scale = 1.0
if dtype == np.int8:
self.inp = np.random.randint(-2, 2, size=[voxels_np.shape[0],
self.inp = np.random.randint(-1, 1, size=[voxels_np.shape[0],
C]).astype(np.int8)
self.weight = np.random.randint(-2, 2, size=[K, *self.ksize,
self.weight = np.random.randint(-1, 1, size=[K, *self.ksize,
C]).astype(np.int8)
self.output = np.random.randint(-2, 2, size=[
self.output = np.random.randint(-1, 1, size=[
self.out_inds.shape[0], K
]).astype(dtype)
self.bias = np.random.randint(-2, 2, size=[
K
]).astype(dtype)
]).astype(out_dtype)
self.output_add = np.random.randint(-1, 1, size=[
self.out_inds.shape[0], K
]).astype(out_dtype)
self.output_add_scale = 14.2
if check_int8_infer:
self.bias = np.random.uniform(-5, 5, size=[
K
]).astype(dtype_comp)
else:
self.bias = np.random.randint(-4, 4, size=[
K
]).astype(dtype)
else:
self.inp = np.random.uniform(-1, 1, size=[
voxels_np.shape[0], C
......@@ -178,28 +200,31 @@ class SparseConvTester:
self.weight = np.random.uniform(-1, 1, size=[K, *self.ksize, C]).astype(dtype)
self.output = np.random.uniform(-1, 1, size=[
self.out_inds.shape[0], K
]).astype(dtype)
]).astype(out_dtype)
self.output_add = np.random.uniform(-1, 1, size=[
self.out_inds.shape[0], K
]).astype(out_dtype)
self.bias = np.random.uniform(-1, 1, size=[
K
]).astype(dtype)
# self.bias[:] = 0
# self.scales[:] = 1
self.weight_ref = self.weight.transpose(1, 2, 3, 0, 4)
self.weight_ref = np.ascontiguousarray(self.weight_ref).reshape(-1, K, C)
self.out_ref, self.din_ref, self.dw_ref = self._get_ref_output()
if check_bias:
self.out_ref += self.bias
# relu
if check_act:
self.out_ref = np.maximum(self.out_ref, 0)
self.dw_ref = np.ascontiguousarray(self.dw_ref.transpose(1, 0, 2).reshape(K, *self.ksize, C))
self.arch = tv.get_compute_capability()
def get_output_ref_spt(self):
return SparseConvTensor(torch.from_numpy(self.out_ref).cuda(), self.ref_out_inds, self.out_shape, self.bs)
def _get_ref_output(self):
output_ref = np.zeros_like(self.output, dtype=np.float32)
out_dtype = np.float32
if self.dtype == np.int8:
out_dtype = np.int32
output_ref = np.zeros_like(self.output, dtype=out_dtype)
dinput_ref = np.zeros_like(self.inp, dtype=np.float32)
dw_ref = np.zeros_like(self.weight_ref,
dtype=np.float32) # KV, K, C
......@@ -215,9 +240,14 @@ class SparseConvTester:
i_inds = self.indice_pairs_np[0][filter_offset][:nhot]
o_inds = self.indice_pairs_np[1][filter_offset][:nhot]
a = self.inp[i_inds]
cc = a.astype(
np.float32) @ self.weight_ref[filter_offset].T.astype(
np.float32)
if self.dtype == np.int8:
cc = a.astype(
np.int32) @ self.weight_ref[filter_offset].T.astype(
np.int32)
else:
cc = a.astype(
np.float32) @ self.weight_ref[filter_offset].T.astype(
np.float32)
output_ref[o_inds] += cc
# we use random output as dout here
a = self.output[self.out_order][o_inds]
......@@ -233,8 +263,25 @@ class SparseConvTester:
dw_res = out_gather.astype(
np.float32).T @ inp_gather.astype(np.float32)
dw_ref[filter_offset] = dw_res
if not self.check_int8_infer:
if self.check_bias:
output_ref += self.bias
# relu
if self.check_act:
output_ref = np.maximum(output_ref, 0)
if self.dtype == np.int8:
output_ref = np.clip(output_ref, -127, 127)
if self.check_int8_infer:
rescaled = output_ref.astype(self.dtype_comp) * self.scales.astype(self.dtype_comp)
rescaled += self.bias.astype(self.dtype_comp)
rescaled += self.output_add[self.out_order].astype(self.dtype_comp) * self.output_add_scale
if self.check_act:
rescaled = np.maximum(rescaled, 0)
if self.out_dtype == np.int8:
output_ref = np.clip(np.round(rescaled), -128, 127).astype(np.int8)
else:
output_ref = rescaled.astype(self.out_dtype)
else:
output_ref = np.clip(output_ref, -127, 127)
return output_ref, dinput_ref, dw_ref
def get_operands(self, op_type: ConvOpType):
......@@ -248,7 +295,7 @@ class SparseConvTester:
else:
weight_tv = tv.from_numpy(self.weight).cuda()
if op_type == ConvOpType.kForward:
output_tv = zeros_func(list(self.output.shape), self.dtype, 0)
output_tv = zeros_func(list(self.output.shape), self.out_dtype, 0)
else:
output_tv = tv.from_numpy(self.output).cuda()
return inp_tv, weight_tv, output_tv
......@@ -280,26 +327,31 @@ def _test_impgemm_conv_cuda(subm: bool):
device = torch.device("cuda:0")
shapes = [[19, 18, 17]]
batchsizes = [1]
dtypes = [np.float32, np.float16]
# dtypes = [(np.float32, np.float32), (np.float16, np.float16)]
# dtypes = [np.float16]
dtypes = [(np.int8, np.int8), (np.int8, np.float32), (np.int8, np.float16)]
dtypes = [(np.int8, np.int8)]
# dtypes = [(np.float16, np.float16)]
# dtypes = [np.int8]
test_case = TestCase()
# in_channels = [32]
# out_channels = [32, 48, 64]
in_channels = [32, 47]
out_channels = [32, 48, 62]
# in_channels = [32]
# out_channels = [32]
in_channels = [16]
out_channels = [16]
multiple_base = 16
if subm:
ksizes = [3, (3, 3, 5), (3, 5, 5), 5]
# ksizes = [3, (3, 3, 5), (3, 5, 5), 5]
ksizes = [3]
strides = [1]
paddings = [0]
dilations = [1]
else:
ksizes = [2, 3, (3, 3, 4), 4, (4, 5, 5), 5]
ksizes = [2, 3]
strides = [1, 2, 3]
paddings = [0, 1]
dilations = [1, 2]
......@@ -310,9 +362,16 @@ def _test_impgemm_conv_cuda(subm: bool):
]
arch = torch.cuda.get_device_capability()
force_nvrtc = False
for shape, bs, C, K, k, s, p, d, algo, dtype in tqdm.tqdm(params_grid(
for shape, bs, C, K, k, s, p, d, algo, dtype_outdtype in tqdm.tqdm(params_grid(
shapes, batchsizes, in_channels, out_channels, ksizes,
strides, paddings, dilations, algos, dtypes)):
dtype, out_dtype = dtype_outdtype
if (C % 16 != 0 or K % 16 != 0) and dtype == np.int8:
continue
dcomp = np.float32
check_int8_infer = True
if dtype != np.int8:
check_int8_infer = False
shape_prod = np.prod(shape)
num_batch = np.random.randint(int(0.2 * shape_prod), int(0.7 * shape_prod))
# C = np.random.randint(int(0.3 * C), int(0.7 * C))
......@@ -320,32 +379,51 @@ def _test_impgemm_conv_cuda(subm: bool):
multipler = max(C, K) / multiple_base
multipler = max(multipler, 1.0)
# print(num_batch)
tester = SparseConvTester(algo, subm, shape, bs, dtype, num_batch, K, C, k, s, p, d, check_bias=True, check_act=True)
tester = SparseConvTester(algo, subm, shape, bs, dtype, out_dtype, num_batch, K, C, k, s, p, d,
check_bias=True, check_act=True, check_int8_infer=check_int8_infer, dtype_comp=np.float32)
enable_dy_mask = tester.kv > 32
output_add_cuda = tv.from_numpy(tester.output_add).cuda()
bias = None
scales = None
act = tv.gemm.Activation.None_
if tester.check_bias:
bias = tv.from_numpy(tester.bias).cuda()
if check_int8_infer:
bias = tv.from_numpy(tester.bias.astype(dcomp)).cuda()
else:
bias = tv.from_numpy(tester.bias).cuda()
if check_int8_infer:
scales = tv.from_numpy(tester.scales.astype(dcomp)).cuda()
atol, rtol = dtype_to_tol[dtype]
mask_width_to_mask_out_fwd: Dict[int, torch.Tensor] = {}
mask_width_to_mask_out_bwd: Dict[int, torch.Tensor] = {}
op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput]
spk = 1
for op_type in op_types:
if tester.dtype == np.int8 and op_type != ConvOpType.kForward:
continue
inp_tv, weight_tv, output_tv = tester.get_operands(op_type)
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False,
use_tf32=SPCONV_ALLOW_TF32)
use_tf32=SPCONV_ALLOW_TF32, bias=bias if bias is not None else tv.Tensor(),
scale=scales if scales is not None else tv.Tensor())
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1,
use_tf32=SPCONV_ALLOW_TF32)
use_tf32=SPCONV_ALLOW_TF32, bias=bias if bias is not None else tv.Tensor(),
scale=scales if scales is not None else tv.Tensor())
if op_type == ConvOpType.kForward and tester.check_act:
act = tv.gemm.Activation.ReLU
else:
act = tv.gemm.Activation.None_
assert avail_desps
for desp in avail_desps:
dcomp = get_npdtype_from_tvdtype(desp.dcomp)
if enable_dy_mask and not desp.dynamic_mask:
continue
if tester.check_int8_infer and not desp.is_int8_inference:
continue
if not subm:
if op_type == ConvOpType.kForward:
output_tv.zero_()
......@@ -353,11 +431,13 @@ def _test_impgemm_conv_cuda(subm: bool):
inp_tv.zero_()
# this algo must success
mask_width = desp.tile_shape[0]
alpha = 1.0
if tester.check_int8_infer:
alpha = tester.output_scale
# if mask_width != 32:
# continue
if mask_width not in mask_width_to_mask_out_fwd:
mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, tester.mask_int_count * div_up(tester.out_inds.shape[0], mask_width)],
mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, div_up(tester.out_inds.shape[0], mask_width), tester.mask_int_count],
dtype=torch.int32,
device=tester.device)
mask_output_fwd = mask_width_to_mask_out_fwd[mask_width]
......@@ -365,6 +445,11 @@ def _test_impgemm_conv_cuda(subm: bool):
bias_cur = bias
if op_type != ConvOpType.kForward:
bias_cur = None
output_add_cur_tv = tv.Tensor()
output_add_cur = None
if is_fwd and tester.check_int8_infer:
output_add_cur = output_add_cuda
output_add_cur_tv = output_add_cur
if subm:
if desp.op_type.value == ConvOpType.kForward.value:
indice_pairs = tester.pair_fwd
......@@ -376,10 +461,13 @@ def _test_impgemm_conv_cuda(subm: bool):
# print([bin(x.item()) for x in masks])
for j in range(tester.num_split):
beta = 1 if j > 0 else 0
if bias_cur is not None:
if bias_cur is not None and not tester.check_int8_infer:
# this beta is used for C-beta (use C as bias, not standalone bias)
beta = 1
if j > 0:
bias_cur = None
if output_add_cur is not None and tester.check_int8_infer:
beta = tester.output_add_scale
mask_filter = tester.masks[j].item()
reverse_mask = False
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
......@@ -396,7 +484,6 @@ def _test_impgemm_conv_cuda(subm: bool):
# desp.is_nvrtc = True
# print(force_nvrtc, desp.op_type, op_type)
if SPCONV_CPP_GEMM:
CONV_CPP.run_with_tuned_result(
ConvTuneResult(desp, tester.arch, spk),
desp.op_type.value,
......@@ -410,13 +497,14 @@ def _test_impgemm_conv_cuda(subm: bool):
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
alpha=alpha,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd and bias_cur is not None else tv.Tensor(),
scale=scales if is_fwd and scales is not None else tv.Tensor(),
act_type=act,
mask_int_count=tester.mask_int_count,
)
output_add=output_add_cur_tv)
else:
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
......@@ -431,12 +519,14 @@ def _test_impgemm_conv_cuda(subm: bool):
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
alpha=alpha,
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
bias=bias_cur if is_fwd else None,
scale=scales if is_fwd else None,
act_type=act,
mask_int_count=tester.mask_int_count
output_add=output_add_cur,
)
else:
......@@ -465,10 +555,13 @@ def _test_impgemm_conv_cuda(subm: bool):
for j in range(tester.num_split):
# beta = 1 if j == 1 else 0
beta = 1 if j > 0 else 0
if bias_cur is not None:
if bias_cur is not None and not tester.check_int8_infer:
# this beta is used for C-beta (use C as bias, not standalone bias)
beta = 1
if j > 0:
bias_cur = None
if output_add_cur is not None and tester.check_int8_infer:
beta = tester.output_add_scale
mask_filter = tester.masks[j].item()
reverse_mask = False
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
......@@ -476,7 +569,6 @@ def _test_impgemm_conv_cuda(subm: bool):
else:
mask_op = mask_ops[j]
if SPCONV_CPP_GEMM:
CONV_CPP.run_with_tuned_result(
ConvTuneResult(desp, tester.arch, spk),
desp.op_type.value,
......@@ -493,9 +585,10 @@ def _test_impgemm_conv_cuda(subm: bool):
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
bias=bias if is_fwd and bias is not None else tv.Tensor(),
bias=bias_cur if is_fwd and bias_cur is not None else tv.Tensor(),
scale=scales if is_fwd and scales is not None else tv.Tensor(),
act_type=act,
mask_int_count=tester.mask_int_count,
output_add=output_add_cur_tv,
)
else:
CONV.run_with_tuned_result(
......@@ -514,9 +607,10 @@ def _test_impgemm_conv_cuda(subm: bool):
beta=beta,
verbose=False,
force_nvrtc=force_nvrtc,
bias=bias if is_fwd else None,
bias=bias_cur if is_fwd else None,
scale=scales if is_fwd else None,
act_type=act,
mask_int_count=tester.mask_int_count,
output_add=output_add_cur,
)
out_ref = tester.out_ref
......@@ -526,6 +620,8 @@ def _test_impgemm_conv_cuda(subm: bool):
out_my = output_tv.cpu().numpy()
out_my = out_my[tester.out_order]
if dtype != np.float16:
if dtype == np.int8:
print("max int8 diff", np.abs(out_ref - out_my).max())
test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol)
else:
error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1))
......@@ -539,86 +635,86 @@ def _test_impgemm_conv_cuda(subm: bool):
else:
error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1))
assert error_norm < 10 * multipler, f"{desp}, {error_norm}, {k}, {s}, {p}, {d}"
inp_tv, weight_tv, output_tv = tester.get_operands(ConvOpType.kBackwardWeight)
for spk in [1, 4, 16, 64]:
for mask_width, mask_output in mask_width_to_mask_out_fwd.items():
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps:
weight_tv.zero_()
if subm:
indice_pairs = tester.pair_fwd
for j in range(tester.num_split):
beta = 0
mask_filter = tester.masks[j].item()
mask_op = mask_output[j]
mask_op_tv = torch_tensor_to_tv(mask_op, dtype=tv.uint32)
# mask_op_np = mask_op_tv.cpu().numpy()
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
mask_op_tv,
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
tv.Tensor(),
torch_tensor_to_tv(indice_pairs),
reverse_mask=False,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
mask_int_count=tester.mask_int_count,
)
if not tester.check_int8_infer:
inp_tv, weight_tv, output_tv = tester.get_operands(ConvOpType.kBackwardWeight)
for spk in [1, 4, 16, 64]:
for mask_width, mask_output in mask_width_to_mask_out_fwd.items():
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else:
indice_pairs = tester.pair_fwd # inp -> out
mask_ops = tester.pair_mask_fwd_splits
mask_argsorts = tester.mask_argsort_fwd_splits
for j in range(tester.num_split):
# beta = 1 if j == 1 else 0
beta = 0
mask_filter = tester.masks[j].item()
reverse_mask = False
mask_op = mask_output[j]
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps:
if enable_dy_mask and not desp.dynamic_mask:
continue
weight_tv.zero_()
if subm:
indice_pairs = tester.pair_fwd
for j in range(tester.num_split):
beta = 0
mask_filter = tester.masks[j].item()
mask_op = mask_output[j]
mask_op_tv = torch_tensor_to_tv(mask_op, dtype=tv.uint32)
# mask_op_np = mask_op_tv.cpu().numpy()
# bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0)
# bit_my = mask_filter
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
mask_op_tv,
torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]),
tv.Tensor(),
torch_tensor_to_tv(indice_pairs),
reverse_mask=False,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
else:
indice_pairs = tester.pair_fwd # inp -> out
mask_ops = tester.pair_mask_fwd_splits
mask_argsorts = tester.mask_argsort_fwd_splits
for j in range(tester.num_split):
# beta = 1 if j == 1 else 0
beta = 0
mask_filter = tester.masks[j].item()
reverse_mask = False
mask_op = mask_output[j]
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
mask_int_count=tester.mask_int_count,
)
dw_ref = tester.dw_ref
dw_my = weight_tv.cpu().numpy()
if dtype != np.float16:
# print(desp, spk, K, C, mask_width, algo)
test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol)
else:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
# print(desp, error_norm)
if (error_norm > 5):
print(f"{desp}, Error={error_norm}, {spk}")
assert error_norm < 10 * multipler
CONV.run_with_tuned_result(
BestConvAlgoByProfile(desp, tester.arch, spk),
desp.op_type.value,
inp_tv,
weight_tv,
output_tv,
torch_tensor_to_tv(mask_op, dtype=tv.uint32),
torch_tensor_to_tv(mask_argsorts[j]),
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
torch_tensor_to_tv(indice_pairs),
reverse_mask,
mask_filter=mask_filter,
mask_width=mask_width,
beta=beta,
verbose=False,
)
dw_ref = tester.dw_ref
dw_my = weight_tv.cpu().numpy()
if dtype != np.float16:
test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol)
else:
error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1))
# print(desp, error_norm)
if (error_norm > 5):
print(f"{desp}, Error={error_norm}, {spk}")
assert error_norm < 10 * multipler
def _test_native_conv_cuda(subm: bool):
ndim = 3
......@@ -924,7 +1020,7 @@ def _test_native_conv_cuda(subm: bool):
def test_all_algo_unit():
# for i in range(5):
_test_impgemm_conv_cuda(True)
# _test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment