Commit 42c7cdad authored by yan.yan

v2.3.0: int8 quantization

parent 1f6deed6
@@ -16,7 +16,7 @@ jobs:
     strategy:
       matrix:
         python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-        cuda-version: ['10.2', '11.3', '11.4', '11.6', '11.7', '11.8']
+        cuda-version: ['10.2', '11.3', '11.4', '11.6', '11.7', '11.8', '12.0']
     steps:
       - uses: actions/checkout@master
       - uses: dorny/paths-filter@v2
@@ -116,7 +116,7 @@ jobs:
     strategy:
       matrix:
         python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] # this version is only used for upload.
-        cuda-version: ['102', '113', '114', '116', '117', '118', '']
+        cuda-version: ['102', '113', '114', '116', '117', '118', '120', '']
     steps:
       - uses: actions/checkout@master
......
 # Changelog
+## [2.3.0] - 2023-01-19
+### Added
+- Add int8 quantization support
+- Add large kernel support for implicit gemm (kv <= 128)
 ## [2.2.6] - 2022-11-06
 ### Fixed
 - CI fail because of pypi temporary shutdown. assign a new version and run again.
......
@@ -77,6 +77,8 @@ Update Spconv: you **MUST UNINSTALL** all spconv/cumm/spconv-cuxxx/cumm-cuxxx fi
 ## NEWS
+* spconv 2.3: int8 quantization support. see docs and examples for more details.
 * spconv 2.2: ampere feature support (by [EvernightAurora](https://github.com/EvernightAurora)), pure c++ code generation, nvrtc, drop python 3.6
 ## Spconv 2.2 vs Spconv 2.1
......
## Int8 Guide
This document shows how to perform PTQ, QAT and int8 inference in PyTorch.
**WARNING** spconv int8 only supports the CUDA backend.
**WARNING** spconv int8 PTQ/QAT requires torch >= 1.13.
### Spconv Int8 Support
spconv currently supports int8 kernels under the following requirement: ```input_channel % 32 == 0 && output_channel % 32 == 0```. Int8 kernels run faster than fp16 kernels for the following shapes:
```
C == 32 && K == 64
C == 64 && K == 32
C >= 64 && K >= 64
```
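The channel rules above can be checked before deciding whether a layer is worth quantizing. The following is a minimal sketch and not part of spconv: both helper names are hypothetical, and it assumes `C` is the input channel count and `K` is the output channel count.

```python
def int8_supported(c: int, k: int) -> bool:
    """True if both channel counts satisfy the int8 kernel requirement."""
    return c % 32 == 0 and k % 32 == 0


def int8_faster_than_fp16(c: int, k: int) -> bool:
    """True for the shapes listed above where int8 is expected to beat fp16."""
    if not int8_supported(c, k):
        return False
    return (c == 32 and k == 64) or (c == 64 and k == 32) or (c >= 64 and k >= 64)


for c, k in [(32, 32), (32, 64), (64, 128), (48, 64)]:
    print(c, k, int8_supported(c, k), int8_faster_than_fp16(c, k))
```

A layer such as `(32, 32)` is supported but is not in the list of faster shapes, so it may be a candidate for staying in fp16.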
spconv currently doesn't support int8 pooling operations.
### Prepare model (Common)
We need to modify the model so that it can be symbolically traced by ```torch.fx```. Here are some tips and requirements:
* only ```forward``` and its content can be traced.
* conditional statements, such as ```if``` and ```assert```, can't depend on ```forward``` arguments. Use an environment variable or a global variable to remove asserts during tracing.
* traced functions can't have dynamic arguments (```*args``` and ```**kwargs```). Change them to static arguments or make the function non-traceable.
* the tracer can skip non-traceable code that lives in a top-level function or a ```torch.nn.Module```. Put all non-traceable code in stateless top-level functions or Modules:
    1. write a top-level function (a function declared in global scope), then use ```torch.fx.wrap``` to make sure it's non-traceable
    2. write a Module that contains all non-traceable code.
* all attributes, methods and static functions are lost after tracing, except those of non-traceable modules. If you still want to use a method of a traced module, refactor it into a non-traceable child module, or make it static.
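One way to follow the tip about asserts is to guard them with a module-level flag read from an environment variable, so the branch never depends on a ```forward``` argument. The sketch below is plain Python and does not call ```torch.fx```; the variable name `MY_FX_TRACE_MODE` and the helper `check_feature_dim` are hypothetical.

```python
import os

# Hypothetical flag: set MY_FX_TRACE_MODE=1 before tracing to skip input checks.
FX_TRACE_MODE = os.environ.get("MY_FX_TRACE_MODE", "0") == "1"


def check_feature_dim(ndim: int, trace_mode: bool = FX_TRACE_MODE) -> bool:
    """Validate the feature rank unless tracing is active.

    Returns True when the check ran and False when it was skipped.
    """
    if trace_mode:
        # The branch depends only on a global flag, never on forward() inputs.
        return False
    assert ndim == 2, f"features.dim()={ndim}"
    return True


print(check_feature_dim(2, trace_mode=False))  # check runs
print(check_feature_dim(3, trace_mode=True))   # check skipped
```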
### Prepare model (Spconv)
Spconv int8 supports SubM residual fusion:
```mermaid
graph TD;
X-->Add;
A-->SubMConv;
SubMConv-->BatchNorm;
BatchNorm-->Add;
Add-->ReLU
```
to
```mermaid
graph TD;
X-->SubMConvAddReLU;
A-->SubMConvAddReLU
```
Due to limitations of ```torch.fx```, this fusion requires that your residual code contains no spconv-specific operations such as ```replace_feature```.
The following residual module can't be fused because of ```out.replace_feature``` and ```out.features```: these operations are recorded as standalone nodes in the graph, so it's hard to recognize and fuse them.
```Python
class SparseBasicBlock(spconv.SparseModule):
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = nn.Identity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = x.features
        out = self.conv1_bn_relu(x)
        out = self.conv2_bn(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # replace_feature / .features become standalone graph nodes, which blocks fusion
        out = out.replace_feature(self.relu(out.features + identity))
        return out
```
The following residual module can be fused. It uses ```SparseReLU``` to avoid the ```replace_feature``` node that ```torch.fx``` would otherwise generate:
```Python
class SparseBasicBlock(spconv.SparseModule):
    """residual block that supported by spconv quantization.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
        self.relu = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = x
        # if self.training:
        #     assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
        out = self.conv1_bn_relu(x)
        out = self.conv2_bn(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(out + identity)
        return out
```
### Post Training Quantization (PTQ)
See [example](../example/mnist/mnist_ptq.py) for a runnable MNIST example.
To perform PTQ in PyTorch, we first need to trace the model via ```torch.fx``` and insert observers into it.
spconv provides a simple function to do this:
```Python
import spconv.pytorch.quantization as spconvq
import torch.ao.quantization.quantize_fx as qfx
from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType

model = ...
is_qat = False
qconfig_mapping = spconvq.get_default_spconv_qconfig_mapping(is_qat)
# disable quantization for some layers here:
qconfig_mapping.set_module_name_regex("foo.*bar.*", None)
# disable quantization by type here:
qconfig_mapping.set_object_type(ModuleClass, None)
prepare_cfg = spconvq.get_spconv_prepare_custom_config()
# preserve static attrs for your module here:
prepare_cfg.preserved_attributes = [...]
# add nontraceable modules here:
prepare_cfg.non_traceable_module_classes.extend([...])
backend_cfg = spconvq.get_spconv_backend_config()
# add custom qconfig for your non-traceable operators
# (non_weighted_op_qint8_dtype_config is a dtype config you provide):
backend_cfg.set_backend_pattern_config(
    BackendPatternConfig(some_op_or_module_class)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .set_dtype_configs([non_weighted_op_qint8_dtype_config]))
# keep in mind that all model user attrs are lost in prepared_model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
```
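To preview which submodules a pattern such as ```"foo.*bar.*"``` would exclude, the pattern can be tried against module names with the standard `re` module. This sketch assumes the pattern is matched from the start of the fully qualified module name (`re.match` semantics); the module names are made up for illustration.

```python
import re

pattern = "foo.*bar.*"
module_names = ["foo.block1.bar", "foo.bar.conv", "net.foo.bar", "foo.block1.baz"]

# Names for which quantization would be disabled under the assumption above.
disabled = [name for name in module_names if re.match(pattern, name)]
print(disabled)
```

In a real model the candidate names would come from `model.named_modules()`.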
The ```prepared_model``` contains PTQ observers that calculate int8 scales. Run inference with some test data to calibrate them:
```Python
for x in loader:
    prepared_model(x)
```
After running inference with ```prepared_model```, all parameters required by int8 inference have been calculated.
To run debug int8 inference in PyTorch, we need to do the following:
```Python
with_linear = True
# PyTorch doesn't support per-channel weights for Linear in the native backend, so we implement a simple fake quantization for it.
spconvq.prepare_spconv_torch_inference(with_linear) # must call before convert_fx
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
# transform all torch quantize_per_tensor to custom quantize_per_tensor to support SparseConvTensor
converted_model = spconvq.transform_qdq(converted_model)
# remove all dq node in fused residual node
converted_model = spconvq.remove_conv_add_dq(converted_model)
test(converted_model)
```
After testing, we need to convert the torch model to TensorRT. See [this doc](TENSORRT_INT8_GUIDE.md) for more details.
Since all int8 kernels are compiled at runtime in the spconv Python package, you can set the environment variable ```SPCONV_INT8_DEBUG=1``` to remove most of the candidate int8 kernels and reduce compile time.
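The variable can also be set from Python instead of the shell. This is a small sketch; setting it before importing spconv is an assumption made here to be safe, since this guide does not state when the variable is read.

```python
import os

# Assumption: set the variable before spconv is imported so that it is
# visible whenever spconv decides which int8 kernels to compile.
os.environ["SPCONV_INT8_DEBUG"] = "1"

# import spconv.pytorch as spconv  # import only after the variable is set
print(os.environ["SPCONV_INT8_DEBUG"])
```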
If you get an error saying that some op doesn't support the CUDA backend, just disable quantization for that op in ```qconfig_mapping```.
### Quantization Aware Training (QAT)
See [example](../example/mnist/mnist_qat.py) for a runnable MNIST example.
To perform QAT in PyTorch, we first need to trace the model via ```torch.fx``` and insert observers and fake-quantize nodes into it.
```Python
import spconv.pytorch.quantization as spconvq
import torch.ao.quantization.quantize_fx as qfx
model = ...
is_qat = True
qconfig_mapping = spconvq.get_default_spconv_qconfig_mapping(is_qat)
prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config()
# keep in mind that all model user attrs are lost in prepared_model.
prepared_model = qfx.prepare_qat_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
```
The ```prepared_model``` contains QAT observers and fake-quantize nodes. You can then run training:
```Python
train(prepared_model)
```
After training, we can use the same code as in PTQ to run torch int8 inference.
### Performance Guide
For int8 kernels, we find that disabling sort in implicit gemm can increase performance, so don't forget to disable it:
```C++
bool do_sort = false;
pair_res = SpconvOps::get_indice_pairs_implicit_gemm(
alloc, input_indices_real, batch_size, input_dims,
static_cast<int>(conv_algo), ksize, stride, padding, dilation,
{0, 0, 0}, is_subm, transpose, false /*is_train*/,
reinterpret_cast<std::uintptr_t>(stream), out_inds_num_limit,
tv::CUDAKernelTimer(false), use_direct_table, do_sort);
```
# TensorRT Int8 Guide
## Prerequisites
### Plugin
Due to limitations of TensorRT, the following requirements must be satisfied:
1. pad all inputs to a static shape.
2. use a tensor to store the current number of voxels, copy it to the CPU, and slice all inputs to their real shape during inference (enqueue).
3. ```supportsFormatCombination``` must allow exactly one combination, i.e. we must set the dtype of this layer during network build. For example, if we want to use fp16, this function must accept fp16 and reject other dtypes to prevent TensorRT from performing dtype/format selection during engine build.
4. The number of dimensions of an int8 tensor for a plugin must be greater than or equal to 3 (tested in TensorRT 8.4).
5. TensorRT version >= 8.4; TensorRT 8.0 doesn't support int8 plugins.
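Requirements 1 and 2 can be pictured with plain Python lists standing in for tensors. This is only a sketch of the pad-then-slice idea, not plugin code; `pad_to_static` and `slice_to_real` are hypothetical helpers.

```python
from typing import List, Tuple

Row = List[float]


def pad_to_static(features: List[Row], max_voxels: int) -> Tuple[List[Row], int]:
    """Pad a [num_voxels, channels] list with zero rows up to max_voxels."""
    num_voxels = len(features)
    if num_voxels > max_voxels:
        raise ValueError("max_voxels is smaller than the real voxel count")
    channels = len(features[0]) if features else 0
    padding = [[0.0] * channels for _ in range(max_voxels - num_voxels)]
    return features + padding, num_voxels


def slice_to_real(padded: List[Row], num_voxels: int) -> List[Row]:
    """Recover the real rows using the stored voxel count."""
    return padded[:num_voxels]


feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
padded, count = pad_to_static(feats, max_voxels=5)
print(len(padded), count)
print(slice_to_real(padded, count) == feats)
```

The network always sees the padded shape, while the stored count tells the plugin how many rows are real.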
### Pytorch
* PTQ/QAT model is ready
### Spconv Int8 Scale/Bias Format
basic rule:
```C++
fp32_data = float(int8_data) * scale
int8_data = int8_t(saturate(round(fp32_data / scale)))
```
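The basic rule can be tried out in plain Python. This is a minimal sketch with hypothetical function names; note that Python's `round` resolves exact ties to the nearest even integer, which may differ from the rounding used by the CUDA kernels.

```python
def quantize(fp32_value: float, scale: float) -> int:
    """int8_data = int8_t(saturate(round(fp32_data / scale)))"""
    q = round(fp32_value / scale)
    return max(-128, min(127, q))  # saturate to the int8 range


def dequantize(int8_value: int, scale: float) -> float:
    """fp32_data = float(int8_data) * scale"""
    return float(int8_value) * scale


scale = 0.5
q = quantize(3.2, scale)
print(q, dequantize(q, scale))
print(quantize(100.0, scale), quantize(-100.0, scale))
```

Values outside the representable range are clamped, so `100.0` and `-100.0` map to the int8 limits with this scale.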
Assume we have a PyTorch quantized layer. The scale and bias required by spconv int8 are:
```Python
import numpy as np
import spconv.pytorch.quantization.quantized.reference as snnqr
import spconv.pytorch.quantization.intrinsic.quantized as snniq
import spconv.pytorch.quantization.quantized as snnq

input_scale = ...
output_scale = ...
if isinstance(layer, snnqr.SpConv):
    q_weight = layer.get_quantized_weight()  # for snnqr.SpConv
    bias_np = layer.bias.detach().cpu().numpy()
elif isinstance(layer, (snniq.SparseConvReLU, snniq.SparseConvAddReLU, snnq.SparseConv)):
    q_weight = layer.weight()  # for quantized layers
    bias_np = layer.bias().detach().cpu().numpy()
else:
    raise NotImplementedError
w_perchannel_scales = q_weight.q_per_channel_scales().detach().cpu().numpy().astype(np.float32)
# fold input, weight and output scales into one per-channel multiplier
scale_for_spconv_implicit_gemm = (input_scale * w_perchannel_scales) / output_scale
# express the fp32 bias in units of the output scale
bias_for_spconv_implicit_gemm = bias_np / output_scale
```
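A small numeric check may help to see why the scale and bias are folded this way. The sketch below uses made-up values for a single output channel and assumes the kernel computes `acc * scale + bias` on the int32 accumulator before rounding to int8; that form is an assumption for illustration, not taken from the kernel source.

```python
input_scale = 0.04
output_scale = 0.05
w_scale = 0.01       # per-channel weight scale of this output channel
bias = 0.3           # fp32 bias of this output channel

x_q = [10, -20, 30]  # int8 inputs
w_q = [5, 7, -3]     # int8 weights

# int32 accumulator as an int8 GEMM would produce it
acc = sum(x * w for x, w in zip(x_q, w_q))

# Reference: dequantize, add the fp32 bias, then express in output-scale units.
reference = (acc * input_scale * w_scale + bias) / output_scale

# Folded form using the two constants computed in the snippet above.
scale_folded = (input_scale * w_scale) / output_scale
bias_folded = bias / output_scale
folded = acc * scale_folded + bias_folded

print(acc, reference, folded)
```

Both paths give the same value up to floating-point error, so the kernel only needs one multiply and one add per output element.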
then we can feed them to ```implicit_gemm```:
```C++
// output_add and output_add_scale: for fused conv-add-relu layer
ConvGemmOps::implicit_gemm(
allocator, tuner, features_int8, weight_int8, pair_fwd,
pair_mask_splits, mask_argsort_splits, actual_out_feature_num,
mask_tensor, arch, false, is_subm,
reinterpret_cast<std::uintptr_t>(stream), tv::CUDAKernelTimer(false),
false, false, bias_for_spconv_implicit_gemm, 1.0,
0.0, tv::gemm::Activation::kReLU, false /*use_tf32*/, output_scale,
scale_for_spconv_implicit_gemm, output_add, output_add_scale);
```
### Explicit Mode or Implicit
There are two int8 modes in TensorRT: implicit and explicit.
For implicit mode, we can use the TensorRT int8 calibrator to calculate scales and use them in the plugin. This isn't tested and isn't covered here.
For explicit mode, we insert QDQ (quantize/dequantize) nodes into the network; TensorRT fuses QDQ and converts layers to quantized ones based on the QDQ layers. See [this doc](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks).
There is an important drawback in TensorRT int8: TensorRT won't fuse QDQ for custom int8 plugins. So we must fuse QDQ ourselves (in PyTorch), and **keep QDQ** in regular layers such as linear and convolution.
PyTorch adds QDQ in ```convert_fx``` and ```convert_to_reference_fx```.
```convert_to_reference_fx```: inserts QDQ and converts fused modules to reference modules, but it **doesn't** fuse any QDQ in your network. If we don't want to write fusion code manually, we can't use this function.
```convert_fx```: inserts QDQ and converts fused modules to quantized modules for the native (CPU) backend. This function fuses **ALL** QDQs in your network, but if we want to use TensorRT explicit quantization, we must keep QDQ for regular layers.
Currently we implement this via PyTorch ```convert_fx``` with some hacks:
```Python
import torch.ao.nn.intrinsic as nni
import torch.nn.quantized._reference as nnqr
from torch.ao.quantization.fx._lower_to_native_backend import \
STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP, QBIN_OP_MAPPING
from spconv.pytorch.quantization.backend_cfg import \
SPCONV_STATIC_LOWER_FUSED_MODULE_MAP, SPCONV_STATIC_LOWER_MODULE_MAP
# add spconv layers to support QDQ fusion for sparse conv layers
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP)
STATIC_LOWER_MODULE_MAP.update(SPCONV_STATIC_LOWER_MODULE_MAP)
# remove linear layers to avoid QDQ fusion for linear.
STATIC_LOWER_FUSED_MODULE_MAP.pop(nni.LinearReLU)
STATIC_LOWER_MODULE_MAP.pop(nnqr.Linear)
# run above BEFORE convert_fx
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
# or just use spconvq.prepare_spconv_torch_inference(True)
```
We can also use a spconv function to handle this if the only regular layers in your network are Linear layers:
```Python
spconvq.prepare_spconv_torch_inference(with_linear=False)
```
If your network contains convolutions, you can do the same thing for conv layers. This isn't covered by ```spconvq.prepare_spconv_torch_inference```.
## Steps
### Record number of voxels for each layer
There is an argument in ```SparseConvolution``` layers: ```record_voxel_count```. If you enable it, the max number of voxels is recorded in a registered buffer during inference. Turn it on and run inference over the whole training dataset.
After inference, we know the max number of voxels of each spconv layer, which is required by the TensorRT plugin.
### write ```torch.fx``` based torch->trt conversion
Once the PTQ/QAT model is ready, we can use [```torch.fx.Interpreter```](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter) to transform the traced PyTorch model to TensorRT.
See [example](../example/mnist/mnist_net_transform.py).
@@ -32,6 +32,9 @@ using GemmTunerSimple =
     spconvlib::spconv::csrc::sparse::convops::spops::GemmTuner;
 int main(int argc, char **argv) {
+  bool is_int8 = false;
+  float inp_scale = 0.04;
+  float out_scale = 0.05;
   tv::ssprint("Hello libspconv!!!");
   TV_ASSERT_RT_ERR(argc == 2, "usage: main /path/to/benchmark-pc.jarr, you can "
@@ -161,6 +164,10 @@ int main(int argc, char **argv) {
     // if your kernel volume > 32, you need to use
     // tv::gemm::SparseConvAlgo::kNative. otherwise use kMaskImplicitGemm.
     if (i == 0) {
+      if (is_int8){
+        // native don't support int8
+        continue;
+      }
       auto conv_algo = tv::gemm::SparseConvAlgo::kNative;
       bool inverse = false;
       // native algo code example
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import contextlib
import copy
from typing import Dict, Optional
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import spconv.pytorch as spconv
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization.core import quantize_per_tensor
from spconv.pytorch.quantization.fake_q import \
get_default_spconv_qconfig_mapping
import spconv.pytorch.quantization.intrinsic.quantized as snniq
from spconv.pytorch.quantization.interpreter import NetworkInterpreter, register_node_handler, register_method_handler
import spconv.pytorch.quantization.intrinsic as snni
import spconv.pytorch.quantization.intrinsic.quantized as snniq
import spconv.pytorch.quantization.quantized as snnq
import spconv.pytorch.quantization.quantized.reference as snnqr
from spconv.pytorch.cppcore import torch_tensor_to_tv
import numpy as np
import spconv.constants as spconvc
# enable trace mode here, or use environment variable SPCONV_FX_TRACE_MODE=1
spconvc.SPCONV_FX_TRACE_MODE = True
@contextlib.contextmanager
def identity_ctx():
    yield


class SubMConvBNReLU(spconv.SparseSequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(SubMConvBNReLU, self).__init__(
            spconv.SubMConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm1d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )


class SparseConvBNReLU(spconv.SparseSequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(SparseConvBNReLU, self).__init__(
            spconv.SparseConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm1d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )


class SparseBasicBlock(spconv.SparseModule):
    """residual block that supported by spconv quantization.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
        self.relu = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = x
        # if self.training:
        #     assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
        out = self.conv1_bn_relu(x)
        out = self.conv2_bn(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(out + identity)
        return out


class ResidualNetPTQ(nn.Module):
    """pytorch currently don't support cuda int8 inference, so
    we build a pure sparse network here.
    """

    def __init__(self):
        super(ResidualNetPTQ, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            # spconv.ToDense(),
        )
        # self.fc1 = nn.Linear(64 * 1 * 1, 128)
        # self.fc2 = nn.Linear(128, 10)
        # self.dropout1 = nn.Dropout2d(0.25)
        # self.dropout2 = nn.Dropout2d(0.5)

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        # x: [N, 28, 28, 1], must be NHWC tensor
        # x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x_sp = self.net(x_sp)
        # print(x_sp.shape)
        x = x_sp
        # x = torch.flatten(x, 1)
        # x = self.dequant(x)
        # output = F.log_softmax(x, dim=1)
        return x


def calibrate(args, model: torch.nn.Module, data_loader, device):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            image = image.to(device)
            if args.sparse:
                data_sp = spconv.SparseConvTensor.from_dense(image.reshape(-1, 28, 28, 1))
                output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
                # output = model(data_sp)
            else:
                output = model(image)


# add module handler
@register_node_handler(snni.SpconvReLUNd)
def _spconv_fused_relu(net, target: snni.SpconvReLUNd, args, kwargs, name: str):
    # add plugin here...
    print("add sparse conv plugin here...", target, name)
    return args[0]


@register_node_handler(snni.SpconvAddReLUNd)
def _spconv_fused_add_relu(net, target: snni.SpconvAddReLUNd, args, kwargs, name: str):
    # add plugin here...
    print("add sparse conv plugin here...", target, name)
    return args[0]


@register_node_handler(snniq.SparseConvReLU)
def _spconv_fused_q_relu(net, target: snniq.SparseConvReLU, args, kwargs, name: str):
    # add plugin here...
    print("add sparse conv plugin here...", target, name)
    return args[0]


@register_node_handler(snniq.SparseConvAddReLU)
def _spconv_fused_q_add_relu(net, target: snniq.SparseConvAddReLU, args, kwargs, name: str):
    # add fused conv-add-relu plugin here...
    inp0 = args[0]
    inp1 = args[1]
    print("add fused sparse conv add relu plugin here...", target, name)
    return args[0]


@register_node_handler(snnqr.SpConv)
def _spconv_r(net, target: snnqr.SpConv, args, kwargs, name: str):
    # add plugin here...
    input_scale = args[0].int8_scale
    output_scale = target.scale
    q_weight = target.get_quantized_weight()
    w_scales = q_weight.q_per_channel_scales().detach().cpu().numpy().astype(np.float32)
    bias_np = target.bias.detach().cpu().numpy()
    w = torch_tensor_to_tv(q_weight).cpu().numpy()
    # spconv int8 format
    channel_scale = (input_scale * w_scales) / output_scale
    bias_np = bias_np / output_scale
    print("add sparse conv plugin here...", target, name)
    return args[0]


@register_node_handler(snnq.SparseConv)
def _spconv_fused_q(net, target: snnq.SparseConv, args, kwargs, name: str):
    # add plugin here...
    print("add sparse conv plugin here...", target, name)
    return args[0]


@register_node_handler(spconv.SparseConvTensor)
def _get_sparse_conv_tensor(net, target: spconv.SparseConvTensor, args, kwargs, name: str):
    return spconv.SparseConvTensor(*args, **kwargs)


# add tensor method handler
@register_method_handler("replace_feature", spconv.SparseConvTensor)
def _replace_new_feature(net, target, args, kwargs, name: str):
    input: spconv.SparseConvTensor = args[0]
    if isinstance(input, spconv.SparseConvTensor):
        return input.replace_feature(*args[1:])
    else:
        raise NotImplementedError


@register_node_handler(quantize_per_tensor)
def _quantize_per_tensor(net, target, args, kwargs, name: str):
    inp: spconv.SparseConvTensor = args[0]
    scale = args[1].detach().cpu().numpy()
    zero_point = args[2]
    print("implement quantize here...", name, scale)
    # WARNING
    # we need to store scale to SparseConvTensor because pytorch dequantize don't
    # have any argument
    inp.int8_scale = scale
    return inp


@register_method_handler("dequantize", spconv.SparseConvTensor)
def _dequantize(net, target, args, kwargs, name: str):
    inp: spconv.SparseConvTensor = args[0]
    assert inp.int8_scale is not None
    print("implement dequantize here...", inp.int8_scale)
    return inp


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--sparse',
                        action='store_true',
                        default=True,
                        help='use sparse conv network instead of dense')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--fp16',
                        action='store_true',
                        default=False,
                        help='For mixed precision training')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    model = ResidualNetPTQ().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # here we remove norm to get sparse tensor with lots of zeros
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # here we remove norm to get sparse tensor with lots of zeros
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)
    model.eval()
    spconvq.prepare_spconv_torch_inference(True)
    # tensorrt only support symmetric quantization, per-tensor act and per-channel weight.
    qconfig_mapping = get_default_spconv_qconfig_mapping(is_qat=False)
    prepare_cfg = spconvq.get_spconv_prepare_custom_config()
    backend_cfg = spconvq.get_spconv_backend_config()
    # prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
    # then add observers to fused model.
    prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
    # calibrate: run model with some inputs
    calibrate(args, prepared_model, test_loader, qdevice)
    # convert (ptq): replace intrinsic blocks with quantized modules
    converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
    converted_model = spconvq.transform_qdq(converted_model)
    # test converted ptq model with int8 kernel
    converted_model = spconvq.remove_conv_add_dq(converted_model)
    # use trt ITensor as input here...
    # input is same as converted_model inputs
    # here we just use torch tensor. we can actually use any input here.
    ft = torch.zeros([500, 1], dtype=torch.float32, device=device)
    ind = torch.zeros([500, 3], dtype=torch.int32, device=device)
    interp = NetworkInterpreter(None, converted_model, [ft, ind, 1])
    # get converted outputs from interp
    outputs = interp.run()


if __name__ == '__main__':
    main()
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import contextlib
import copy
from typing import Dict, Optional
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.ao.quantization import (DeQuantStub, QuantStub,
get_default_qconfig_mapping)
from torch.ao.quantization.fx._lower_to_native_backend import \
STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import spconv.pytorch as spconv
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
from spconv.pytorch.quantization.core import quantize_per_tensor
from spconv.pytorch.quantization.fake_q import \
get_default_spconv_qconfig_mapping
from spconv.pytorch.quantization.intrinsic.modules import SpconvBnAddReLUNd, SpconvAddReLUNd
import spconv.pytorch.quantization.intrinsic.quantized as snniq
@contextlib.contextmanager
def identity_ctx():
    yield


class SubMConvBNReLU(spconv.SparseSequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(SubMConvBNReLU, self).__init__(
            spconv.SubMConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm1d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )


class SparseConvBNReLU(spconv.SparseSequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(SparseConvBNReLU, self).__init__(
            spconv.SparseConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm1d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )


class SparseBasicBlock(spconv.SparseModule):
    """residual block that supported by spconv quantization.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        spconv.SparseModule.__init__(self)
        conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
        norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
        norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
        self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
        self.relu = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        identity = x
        # if self.training:
        #     assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
        out = self.conv1_bn_relu(x)
        out = self.conv2_bn(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(out + identity)
        return out
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),
            spconv.ToDense(),
        )
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x_sp: spconv.SparseConvTensor):
        # def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        # x: [N, 28, 28, 1], must be NHWC tensor
        # x = self.quant(x)
        # x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        # x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x = self.net(x_sp)
        x = torch.flatten(x, 1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output
class NetV2(nn.Module):
    def __init__(self):
        super(NetV2, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),
            spconv.ToDense(),
        )
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        # self.dropout1 = nn.Dropout2d(0.25)
        # self.dropout2 = nn.Dropout2d(0.5)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        # x: [N, 28, 28, 1], must be NHWC tensor
        # keep the quantized features so the QuantStub output is actually used below
        features = self.quant(features)
        # x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x = self.net(x_sp)
        x = torch.flatten(x, 1)
        # x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        # x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output
class NetPTQ(nn.Module):
    """PyTorch currently doesn't support CUDA int8 inference, so
    we build a pure sparse network here.
    """
    def __init__(self):
        super(NetPTQ, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        )
        # self.fc1 = nn.Linear(64 * 1 * 1, 128)
        # self.fc2 = nn.Linear(128, 10)
        # self.dropout1 = nn.Dropout2d(0.25)
        # self.dropout2 = nn.Dropout2d(0.5)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        # x: [N, 28, 28, 1], must be NHWC tensor
        features = self.quant(features)
        # x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x_sp = self.net(x_sp)
        # print(x_sp.shape)
        x = x_sp
        x = torch.flatten(x, 1)
        x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output
class ResidualNetPTQ(nn.Module):
    """PyTorch currently doesn't support CUDA int8 inference, so
    we build a pure sparse network here.
    """
    def __init__(self):
        super(ResidualNetPTQ, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            SparseConvBNReLU(64, 64, 3, 2, 1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        )
        # self.fc1 = nn.Linear(64 * 1 * 1, 128)
        # self.fc2 = nn.Linear(128, 10)
        # self.dropout1 = nn.Dropout2d(0.25)
        # self.dropout2 = nn.Dropout2d(0.5)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        # x: [N, 28, 28, 1], must be NHWC tensor
        features = self.quant(features)
        # x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
        x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        # create SparseConvTensor manually: see SparseConvTensor.from_dense
        x_sp = self.net(x_sp)
        # print(x_sp.shape)
        x = x_sp
        x = torch.flatten(x, 1)
        x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output
class NetDense(nn.Module):
    def __init__(self):
        super(NetDense, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.iden = spconv.SparseIdentity()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.iden(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    scaler = torch.cuda.amp.grad_scaler.GradScaler()
    amp_ctx = contextlib.nullcontext()
    if args.fp16:
        amp_ctx = torch.cuda.amp.autocast()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with amp_ctx:
            if args.sparse:
                data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
                # output = model(data_sp)
                output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
            else:
                output = model(data)
            loss = F.nll_loss(output, target)
            scale = 1.0
            if args.fp16:
                assert loss.dtype is torch.float32
                scaler.scale(loss).backward()
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                # scaler.unscale_(optim)
                # Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
                # You may use the same value for max_norm here as you would without gradient scaling.
                # torch.nn.utils.clip_grad_norm_(models[0].net.parameters(), max_norm=0.1)
                scaler.step(optimizer)
                # Updates the scale for next iteration.
                scaler.update()
                scale = scaler.get_scale()
            else:
                loss.backward()
                optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    amp_ctx = contextlib.nullcontext()
    if args.fp16:
        amp_ctx = torch.cuda.amp.autocast()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            with amp_ctx:
                if args.sparse:
                    data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
                    # output = model(data_sp)
                    output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
                else:
                    output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(
                dim=1,
                keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
def calibrate(args, model: torch.nn.Module, data_loader, device):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            image = image.to(device)
            if args.sparse:
                data_sp = spconv.SparseConvTensor.from_dense(image.reshape(-1, 28, 28, 1))
                output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
                # output = model(data_sp)
            else:
                output = model(image)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--sparse',
                        action='store_true',
                        default=True,
                        help='use sparse conv network instead of dense')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--fp16',
                        action='store_true',
                        default=False,
                        help='For mixed precision training')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    if args.sparse:
        model = ResidualNetPTQ().to(device)
    else:
        model = NetDense().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # here we remove norm to get sparse tensor with lots of zeros
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # here we remove norm to get sparse tensor with lots of zeros
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()
    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
    model.eval()
    if not args.sparse:
        model = model.cpu()
    spconvq.prepare_spconv_torch_inference(True)
    # TensorRT only supports symmetric quantization, with per-tensor activations and per-channel weights.
    qconfig_mapping = get_default_spconv_qconfig_mapping(is_qat=False)
    prepare_cfg = spconvq.get_spconv_prepare_custom_config()
    backend_cfg = spconvq.get_spconv_backend_config()
    # prepare: fuse the model; patterns such as conv-bn-relu are fused into modules from
    # torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic,
    # then observers are added to the fused model.
    prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
    # calibrate: run model with some inputs
    calibrate(args, prepared_model, test_loader, qdevice)
    # convert (ptq): replace intrinsic blocks with quantized modules
    converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
    converted_model = spconvq.transform_qdq(converted_model)
    # test converted ptq model with int8 kernel
    converted_model = spconvq.remove_conv_add_dq(converted_model)
    print(converted_model)
    test(args, converted_model, qdevice, test_loader)


if __name__ == '__main__':
    main()
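The comments in `main()` describe the PTQ flow as prepare (fuse and insert observers), calibrate (run some inputs through the model), and convert. As a rough illustration of what a calibration observer computes for symmetric per-tensor int8 quantization, the sketch below tracks the largest absolute value seen and derives a scale from it. It is plain Python with hypothetical names (`MinMaxObserver`, `quantize`, `dequantize`); it is not the observer that spconv or `torch.ao.quantization` actually uses.

```python
from typing import Iterable, List


class MinMaxObserver:
    """Hypothetical observer: records the largest |x| seen during calibration."""

    def __init__(self) -> None:
        self.max_abs = 0.0

    def observe(self, values: Iterable[float]) -> None:
        for v in values:
            self.max_abs = max(self.max_abs, abs(v))

    def scale(self, qmax: int = 127) -> float:
        # Symmetric quantization: the zero point is fixed at 0, so only a scale is needed.
        return self.max_abs / qmax if self.max_abs > 0 else 1.0


def quantize(values: Iterable[float], scale: float, qmax: int = 127) -> List[int]:
    # Round to the nearest integer step and clamp to the symmetric int8 range.
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


def dequantize(qvalues: Iterable[int], scale: float) -> List[float]:
    return [q * scale for q in qvalues]


observer = MinMaxObserver()
for batch in ([0.0, 0.5, -1.0], [2.0, -0.25, 0.0]):  # stand-in calibration batches
    observer.observe(batch)

s = observer.scale()
q = quantize([0.0, 1.0, -2.0, 5.0], s)  # 5.0 lies outside the calibrated range
x_hat = dequantize(q, s)
```

With a zero point of 0, an exact zero in the input stays an exact zero after quantization, and values outside the calibrated range are clamped rather than wrapped.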
...@@ -86,123 +86,21 @@ class SparseBasicBlock(spconv.SparseModule): ...@@ -86,123 +86,21 @@ class SparseBasicBlock(spconv.SparseModule):
self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True)) self.conv1_bn_relu = spconv.SparseSequential(conv=conv1, bn=norm1, relu=nn.ReLU(inplace=True))
self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2) self.conv2_bn = spconv.SparseSequential(conv=conv2, bn=norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.iden_for_fx_match = nn.Identity()
def forward(self, x: spconv.SparseConvTensor):
identity = self.iden_for_fx_match(x.features)
# if self.training:
# assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1_bn_relu(x)
out = self.conv2_bn(out)
if self.downsample is not None:
identity = self.downsample(x)
out = out.replace_feature(self.relu(out.features + identity))
return out
class SparseBasicBlock1(spconv.SparseModule):
"""residual block that supported by spconv quantization.
"""
expansion = 1
def __init__(self,
in_planes, out_planes,
stride=1,
downsample=None):
spconv.SparseModule.__init__(self)
self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
self.norm1 = nn.BatchNorm1d(out_planes, momentum=0.1)
self.norm2 = nn.BatchNorm1d(out_planes, momentum=0.1)
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
self.downsample = downsample
self.iden_for_fx_match = nn.Identity()
def forward(self, x: spconv.SparseConvTensor):
identity = self.iden_for_fx_match(x.features)
# if self.training:
# assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x)
out = out.replace_feature(self.relu1(self.norm1(out.features)))
out = self.conv2(out)
out = out.replace_feature(self.norm2(out.features))
# if self.downsample is not None:
# identity = self.downsample(x)
out = out.replace_feature(self.relu2(out.features + identity))
return out
class SparseBasicBlock2(spconv.SparseModule):
"""residual block that supported by spconv quantization.
"""
expansion = 1
def __init__(self,
in_planes, out_planes,
stride=1,
downsample=None):
spconv.SparseModule.__init__(self)
self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
self.conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
self.norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
self.relu1 = spconv.SparseReLU(inplace=True)
self.relu2 = spconv.SparseReLU(inplace=True)
self.downsample = downsample
self.iden_for_fx_match = spconv.SparseIdentity()
def forward(self, x: spconv.SparseConvTensor):
identity = self.iden_for_fx_match(x)
# if self.training:
# assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x)
out = self.relu1(self.norm1(out))
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.relu2(out + identity)
return out
class SparseBasicBlock3(spconv.SparseModule):
"""residual block that supported by spconv quantization.
"""
expansion = 1
def __init__(self,
in_planes, out_planes,
stride=1,
downsample=None):
spconv.SparseModule.__init__(self)
self.conv1 = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
conv2 = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)
self.norm1 = spconv.SparseBatchNorm(out_planes, momentum=0.1) self.relu = spconv.SparseReLU(inplace=True)
norm2 = spconv.SparseBatchNorm(out_planes, momentum=0.1)
self.residual_conv = SpconvAddReLUNd(conv2, spconv.SparseReLU(inplace=True))
self.relu1 = spconv.SparseReLU(inplace=True)
# self.relu2 = spconv.SparseReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
self.iden_for_fx_match = spconv.SparseIdentity() self.iden_for_fx_match = spconv.SparseIdentity()
def forward(self, x: spconv.SparseConvTensor): def forward(self, x: spconv.SparseConvTensor):
identity = self.iden_for_fx_match(x) identity = x
# if self.training: # if self.training:
# assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}' # assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
out = self.conv1(x) out = self.conv1_bn_relu(x)
out = self.relu1(self.norm1(out)) out = self.conv2_bn(out)
if self.downsample is not None: if self.downsample is not None:
identity = self.downsample(x) identity = self.downsample(x)
out = self.residual_conv(out, identity) out = self.relu(out + identity)
return out return out
class Net(nn.Module): class Net(nn.Module):
...@@ -318,7 +216,7 @@ class ResidualNetPTQ(nn.Module): ...@@ -318,7 +216,7 @@ class ResidualNetPTQ(nn.Module):
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
SubMConvBNReLU(1, 32, 3), SubMConvBNReLU(1, 32, 3),
# SubMConvBNReLU(32, 32, 3), # SubMConvBNReLU(32, 32, 3),
SparseBasicBlock2(32, 32), SparseBasicBlock(32, 32),
SubMConvBNReLU(32, 64, 3), SubMConvBNReLU(32, 64, 3),
SparseConvBNReLU(64, 64, 2, 2), # 14x14 SparseConvBNReLU(64, 64, 2, 2), # 14x14
SparseConvBNReLU(64, 64, 2, 2), # 7x7 SparseConvBNReLU(64, 64, 2, 2), # 7x7
...@@ -584,54 +482,22 @@ def main(): ...@@ -584,54 +482,22 @@ def main():
model = model.cpu() model = model.cpu()
model_qat = copy.deepcopy(model) model_qat = copy.deepcopy(model)
STATIC_LOWER_FUSED_MODULE_MAP.update(SPCONV_STATIC_LOWER_FUSED_MODULE_MAP) spconvq.prepare_spconv_torch_inference(True)
STATIC_LOWER_MODULE_MAP.update(SPCONV_STATIC_LOWER_MODULE_MAP) # do qat
# tensorrt only support symmetric quantization, per-tensor act and per-channel weight. qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
qconfig_mapping = get_default_spconv_qconfig_mapping(False)
prepare_cfg = spconvq.get_spconv_prepare_custom_config() prepare_cfg = spconvq.get_spconv_prepare_custom_config()
backend_cfg = spconvq.get_spconv_backend_config() backend_cfg = spconvq.get_spconv_backend_config()
# convert_cfg = spconvq.get_spconv_convert_custom_config()
# prepare: fuse your model, all patterns such as conv-bn-relu fuse to modules in torch.ao.quantization.intrinsic / spconv.pytorch.quantization.intrinsic
# then add observers to fused model.
prepared_model = qfx.prepare_fx(model, qconfig_mapping, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# print(prepared_model)
# breakpoint()
# print(prepared_model)
# calibrate: run model with some inputs
calibrate(args, prepared_model, test_loader, qdevice)
# convert (ptq): replace intrinsic blocks with quantized modules
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qconfig_mapping, backend_config=backend_cfg)
converted_model = transform_qdq(converted_model)
# test converted ptq model with int8 kernel
remove_conv_add_dq(converted_model)
prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
train(args, prepared_model_qat, qdevice, train_loader, optimizer, 1)
converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
converted_model = spconvq.transform_qdq(converted_model)
# test converted ptq model with int8 kernel
spconvq.remove_conv_add_dq(converted_model)
# you will see some nvrtc compile log here, which means int8 kernel is used.
print(converted_model) print(converted_model)
breakpoint()
test(args, converted_model, qdevice, test_loader) test(args, converted_model, qdevice, test_loader)
# do qat
# qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
# prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
# # converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# # breakpoint()
# print(prepared_model_qat)
# train(args, prepared_model_qat, qdevice, train_loader, optimizer, 1)
# converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
# converted_model = transform_qdq(converted_model)
# test(args, converted_model, qdevice, test_loader)
# # [type(m) for m in prepared_model_qat.modules()]
# # model.qconfig = get_default_spconv_trt_ptq_qconfig()
# # prepare_custom_config_dict = spconvq.get_prepare_custom_config()
# # convert_custom_config_dict = spconvq.get_convert_custom_config()
# # torch.ao.quantization.prepare(model, inplace=True)
# # print('Post Training Quantization Prepare: Inserting Observers')
# # print('\n ConvBnReLUBlock:After observer insertion \n\n', model.net[0])
# # test(args, model, device, test_loader)
# print(converted_model)
# you will see some nvrtc compile log here, which means int8 kernel is used.
breakpoint()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
[build-system] [build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"] requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.7"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"] # requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
...@@ -39,9 +39,9 @@ if cuda_ver: ...@@ -39,9 +39,9 @@ if cuda_ver:
cuda_ver_str = cuda_ver.replace(".", "") # 10.2 to 102 cuda_ver_str = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver_str) RELEASE_NAME += "-cu{}".format(cuda_ver_str)
deps = ["cumm-cu{}>=0.3.7".format(cuda_ver_str)] deps = ["cumm-cu{}>=0.4.0, <0.5.0".format(cuda_ver_str)]
else: else:
deps = ["cumm>=0.3.7"] deps = ["cumm>=0.4.0, <0.5.0"]
......
...@@ -116,5 +116,6 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1 ...@@ -116,5 +116,6 @@ SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32 = False SPCONV_ALLOW_TF32 = False
SPCONV_INT8_DEBUG = os.getenv("SPCONV_INT8_DEBUG", "0") == "1"
SPCONV_INT8_DEBUG = False SPCONV_DO_SORT = os.getenv("SPCONV_DO_SORT", "1") == "1"
\ No newline at end of file \ No newline at end of file
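The `os.getenv(...) == "1"` pattern used for these flags enables an option only when the environment variable is exactly the string `"1"`; the second argument to `os.getenv` supplies the value used when the variable is unset. A small sketch of that behaviour, using a hypothetical helper `env_flag` that accepts the environment as a mapping so it can be exercised without touching the real process environment:

```python
import os
from typing import Mapping, Optional


def env_flag(name: str, default: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True only if the variable (or its default) equals the string "1"."""
    source = os.environ if env is None else env
    return source.get(name, default) == "1"


# SPCONV_DO_SORT defaults to "1", so sorting stays on unless explicitly disabled.
do_sort_default = env_flag("SPCONV_DO_SORT", "1", env={})
do_sort_disabled = env_flag("SPCONV_DO_SORT", "1", env={"SPCONV_DO_SORT": "0"})
# Any value other than "1" (including "true") counts as disabled.
do_sort_true_string = env_flag("SPCONV_DO_SORT", "1", env={"SPCONV_DO_SORT": "true"})
# SPCONV_INT8_DEBUG defaults to "0", so it is off unless set to "1".
int8_debug_default = env_flag("SPCONV_INT8_DEBUG", "0", env={})
```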
...@@ -557,7 +557,7 @@ class SpconvOps: ...@@ -557,7 +557,7 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, preallocated: Dict[str, Tensor] = {}, do_sort: bool = True) -> Tuple[Tensor, int]: def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1, timer: CUDAKernelTimer = CUDAKernelTimer(False), direct_table: bool = False, do_sort: bool = True, preallocated: Dict[str, Tensor] = {}) -> Tuple[Tensor, int]:
""" """
Args: Args:
allocator: allocator:
...@@ -577,8 +577,8 @@ class SpconvOps: ...@@ -577,8 +577,8 @@ class SpconvOps:
num_out_act_bound: num_out_act_bound:
timer: timer:
direct_table: direct_table:
preallocated:
do_sort: do_sort:
preallocated:
""" """
... ...
@staticmethod @staticmethod
......
...@@ -1647,10 +1647,10 @@ class SpconvOps(pccm.Class): ...@@ -1647,10 +1647,10 @@ class SpconvOps(pccm.Class):
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)", code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)") "cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("direct_table", f"bool", "false") code.arg("direct_table", f"bool", "false")
code.arg("do_sort", f"bool", "true")
code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>", code.arg("preallocated", f"std::unordered_map<std::string, tv::Tensor>",
"std::unordered_map<std::string, tv::Tensor>{}", "std::unordered_map<std::string, tv::Tensor>{}",
"Dict[str, cumm.tensorview.Tensor] = {}") "Dict[str, cumm.tensorview.Tensor] = {}")
code.arg("do_sort", f"bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f""" code.raw(f"""
......
...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator ...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import _TORCH_DTYPE_TO_TV, TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul from spconv.pytorch.cppcore import _TORCH_DTYPE_TO_TV, TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32 from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32, SPCONV_DO_SORT
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps from spconv.core_cc.csrc.sparse.inference import InferenceOps
...@@ -342,7 +342,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -342,7 +342,8 @@ def get_indice_pairs_implicit_gemm(
alloc: Optional[ThrustSortAllocator] = None, alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False), timer: CUDAKernelTimer = CUDAKernelTimer(False),
num_out_act_bound: int = -1, num_out_act_bound: int = -1,
direct_table: bool = SPCONV_USE_DIRECT_TABLE): direct_table: bool = SPCONV_USE_DIRECT_TABLE,
do_sort=SPCONV_DO_SORT):
""" """
Why return tuple? because pytorch seems don't support custom object in autograd. Why return tuple? because pytorch seems don't support custom object in autograd.
return: ( return: (
...@@ -382,7 +383,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -382,7 +383,8 @@ def get_indice_pairs_implicit_gemm(
stream, stream,
num_out_act_bound, num_out_act_bound,
timer=timer_cpp, timer=timer_cpp,
direct_table=direct_table) direct_table=direct_table,
do_sort=do_sort)
mask_split_count = mask_tensor.dim(0) mask_split_count = mask_tensor.dim(0)
masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)] masks = [mask_tensor[i:i + 1].numpy() for i in range(mask_split_count)]
if subm: if subm:
...@@ -548,7 +550,8 @@ def get_indice_pairs_implicit_gemm( ...@@ -548,7 +550,8 @@ def get_indice_pairs_implicit_gemm(
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j],
alloc.alloc, alloc.alloc,
mask_argsort_tv[j], stream, mask_argsort_tv[j], stream,
mask_int_count) mask_int_count,
do_sort=do_sort)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)] pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
mask_argsort_in_splits = [ mask_argsort_in_splits = [
...@@ -764,22 +767,22 @@ def get_indice_pairs_implicit_gemm( ...@@ -764,22 +767,22 @@ def get_indice_pairs_implicit_gemm(
alloc.alloc, alloc.alloc,
mask_argsort_fwd_tv[0], mask_argsort_fwd_tv[0],
stream, stream,
mask_int_count) mask_int_count, do_sort=do_sort)
else: else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count) mask_argsort_bwd_tv[0], stream, mask_int_count, do_sort=do_sort)
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count) mask_argsort_fwd_tv[0], stream, mask_int_count, do_sort=do_sort)
else: else:
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator(
pair_mask_fwd_tv[0], alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream, mask_int_count) mask_argsort_fwd_tv[0], stream, mask_int_count, do_sort=do_sort)
SpconvOps.sort_1d_by_key_allocator( SpconvOps.sort_1d_by_key_allocator(
pair_mask_bwd_tv[0], alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], stream, mask_int_count) mask_argsort_bwd_tv[0], stream, mask_int_count, do_sort=do_sort)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
if not is_train: if not is_train:
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
from .backend_cfg import (get_spconv_backend_config, from .backend_cfg import (get_spconv_backend_config,
get_spconv_prepare_custom_config, get_spconv_prepare_custom_config,
get_spconv_convert_custom_config) get_spconv_convert_custom_config,
prepare_spconv_torch_inference)
from .fake_q import (get_default_spconv_trt_ptq_qconfig, from .fake_q import (get_default_spconv_trt_ptq_qconfig,
get_default_spconv_trt_qat_qconfig, get_default_spconv_trt_qat_qconfig,
get_default_spconv_qconfig_mapping) get_default_spconv_qconfig_mapping)
......
...@@ -632,3 +632,18 @@ def get_spconv_convert_custom_config(): ...@@ -632,3 +632,18 @@ def get_spconv_convert_custom_config():
# cfg.set_observed_to_quantized_mapping(snni., snniq.SparseConvReLU) # cfg.set_observed_to_quantized_mapping(snni., snniq.SparseConvReLU)
return cfg return cfg
def prepare_spconv_torch_inference(with_linear: bool):
    from torch.ao.quantization.fx._lower_to_native_backend import \
        STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP
    fmap = SPCONV_STATIC_LOWER_FUSED_MODULE_MAP.copy()
    lmap = SPCONV_STATIC_LOWER_MODULE_MAP.copy()
    if with_linear:
        fmap.update({
            nni.LinearReLU: (nnqr.Linear, snniq.LinearPerChannelWeightReLU),
        })
        lmap.update({
            nnqr.Linear: snnq.LinearPerChannelWeight
        })
    STATIC_LOWER_FUSED_MODULE_MAP.update(fmap)
    STATIC_LOWER_MODULE_MAP.update(lmap)
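`prepare_spconv_torch_inference` copies the spconv lowering maps before optionally adding the linear entries, so the module-level spconv maps are left unchanged and only the torch maps are mutated. A toy sketch of that copy-then-update pattern with plain dictionaries (every name below is a placeholder, not a torch or spconv object):

```python
from typing import Dict

# Placeholder maps standing in for the spconv and torch lowering tables.
SPCONV_MAP: Dict[str, str] = {"SpconvReLU": "QuantizedSpconvReLU"}
TORCH_MAP: Dict[str, str] = {"ConvReLU2d": "QuantizedConvReLU2d"}


def prepare(with_linear: bool) -> None:
    fmap = SPCONV_MAP.copy()  # work on a copy so SPCONV_MAP itself is never modified
    if with_linear:
        fmap.update({"LinearReLU": "LinearPerChannelWeightReLU"})
    TORCH_MAP.update(fmap)  # the shared torch table is extended in place


prepare(with_linear=True)
```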
@@ -49,8 +49,8 @@ def transform_qdq(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
             node.target = quantize_per_tensor
         if node.target == torch.ops.quantized.add:
             node.target = quantized_add
+    m.graph.eliminate_dead_code()
+    m.recompile()
     m.graph.lint()  # Does some checks to make sure the
     # Graph is well-formed.
-    m.recompile()
     return m
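For readers unfamiliar with this kind of `torch.fx` rewrite, here is a minimal sketch of the same sequence as the hunk above: retarget matching nodes, drop dead code, recompile, then lint. It is not the spconv transform. `AddThenRelu` and `replace_add_with_sub` are hypothetical names, and swapping `operator.add` for `operator.sub` is an arbitrary stand-in for the quantized-op replacement.

```python
import operator

import torch
import torch.fx


class AddThenRelu(torch.nn.Module):
    def forward(self, x, y):
        unused = x * 2  # never used, so it becomes a dead node in the graph
        return torch.relu(x + y)


def replace_add_with_sub(m: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in m.graph.nodes:
        # Retarget the node in place; its inputs and users stay connected.
        if node.op == "call_function" and node.target == operator.add:
            node.target = operator.sub
    m.graph.eliminate_dead_code()  # removes the unused multiply
    m.recompile()                  # regenerate forward() from the edited graph
    m.graph.lint()                 # sanity-check that the graph is well-formed
    return m


gm = replace_add_with_sub(torch.fx.symbolic_trace(AddThenRelu()))
out = gm(torch.tensor([3.0, 1.0]), torch.tensor([1.0, 2.0]))
```

Without `recompile()`, the module would keep executing the code generated before the graph was edited.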
from typing import Any, Dict, List, Optional, Set, Type

import torch
import torch.fx

REGISTERED_NODE_HANDLERS: Dict[Any, Any] = {}


def register_node_handler(*names):
    def wrap_func(handler):
        global REGISTERED_NODE_HANDLERS
        for n in names:
            REGISTERED_NODE_HANDLERS[n] = handler

        def new_handler(inputs, attributes, scope):
            return handler(inputs, attributes, scope)

        return new_handler

    return wrap_func


def register_method_handler(name: str, tensor_classes):
    if not isinstance(tensor_classes, (list, tuple)):
        tensor_classes = [tensor_classes]

    def wrap_func(handler):
        global REGISTERED_NODE_HANDLERS
        for tcls in tensor_classes:
            REGISTERED_NODE_HANDLERS[(tcls, name)] = handler

        def new_handler(inputs, attributes, scope):
            return handler(inputs, attributes, scope)

        return new_handler

    return wrap_func


def get_node_handler(name):
    global REGISTERED_NODE_HANDLERS
    msg = "missing handler " + str(name)
    msg += ", available handlers: {}".format(
        list(REGISTERED_NODE_HANDLERS.keys()))
    assert name in REGISTERED_NODE_HANDLERS, msg
    return REGISTERED_NODE_HANDLERS[name]


class NetworkInterpreter(torch.fx.Interpreter):
    def __init__(self,
                 network_ctx,
                 module: torch.fx.GraphModule,
                 inputs: List[Any],
                 verbose: bool = False):
        super().__init__(module)
        self.network_ctx = network_ctx
        self._inputs = inputs
        self._outputs = None
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        self._verbose = verbose

    def run(self):
        super().run(*self._inputs)
        assert self._outputs is not None
        return self._outputs

    def run_node(self, n):
        self._cur_node_name = str(n)
        return super().run_node(n)

    def call_module(self, target, args, kwargs):
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        type_str = submod_type.__qualname__
        type_str_parts = type_str.split(".")
        msg = f"[Module.{type_str_parts[-1]}]{target}({args}|{kwargs}) => "
        try:
            converter = get_node_handler(submod_type)
            res = converter(self.network_ctx, submod, args, kwargs,
                            self._cur_node_name)
            msg += f"{res}"
            if self._verbose:
                print(msg)
            return res
        except Exception as e:
            if self._verbose:
                print(msg)
            raise e

    def call_function(self, target, args, kwargs):
        msg = f"[Func]{target}({args}|{kwargs}) => "
        try:
            converter = get_node_handler(target)
            res = converter(self.network_ctx, target, args, kwargs,
                            self._cur_node_name)
            msg += f"{res}"
            if self._verbose:
                print(msg)
            return res
        except Exception as e:
            if self._verbose:
                print(msg)
            raise e

    def call_method(self, target, args, kwargs):
        msg = f"[Method]{target}({args}|{kwargs}) => "
        assert isinstance(target, str)
        try:
            key = (type(args[0]), target)
            converter = get_node_handler(key)
            res = converter(self.network_ctx, target, args, kwargs,
                            self._cur_node_name)
            msg += f"{res}"
            if self._verbose:
                print(msg)
            return res
        except Exception as e:
            if self._verbose:
                print(msg)
            raise e

    def output(self, target, args, kwargs):
        self._outputs = args
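The file above keys its handler registry in two ways: function and module nodes are looked up by the callable or module type itself, while method nodes are looked up by a `(type(args[0]), method_name)` tuple. The following is a small, torch-free sketch of that dispatch scheme, assuming the five-argument handler signature that `NetworkInterpreter` uses when it calls a converter. `HANDLERS`, `register`, `Vec`, and the two `dispatch_*` helpers are hypothetical names, and unlike the original decorators this simplified `register` returns the handler unchanged.

```python
from typing import Any, Callable, Dict

HANDLERS: Dict[Any, Callable[..., Any]] = {}


def register(*keys: Any):
    """Register one handler under several keys (simplified: no wrapper)."""
    def wrap(handler):
        for key in keys:
            HANDLERS[key] = handler
        return handler
    return wrap


class Vec:
    """Hypothetical value type standing in for a tensor class."""
    def __init__(self, data):
        self.data = list(data)


@register(sum)  # function node: keyed by the callable itself
def handle_sum(ctx, target, args, kwargs, node_name):
    ctx.append(node_name)
    return f"sum({args[0]})"


@register((Vec, "scale"))  # method node: keyed by (type of self, method name)
def handle_scale(ctx, target, args, kwargs, node_name):
    ctx.append(node_name)
    return Vec(x * args[1] for x in args[0].data)


def dispatch_function(ctx, target, args, node_name):
    return HANDLERS[target](ctx, target, args, {}, node_name)


def dispatch_method(ctx, method_name, args, node_name):
    # Same key construction as call_method in the interpreter above.
    key = (type(args[0]), method_name)
    return HANDLERS[key](ctx, method_name, args, {}, node_name)


ctx = []
summed = dispatch_function(ctx, sum, ([1, 2, 3],), "sum_1")
scaled = dispatch_method(ctx, "scale", (Vec([1, 2]), 3), "scale_1")
```

Keying methods by the receiver's type lets two value types share a method name such as `view` or `reshape` while still resolving to different converters.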