spconv currently supports int8 kernels under the following requirement: ```input_channel % 32 == 0 && output_channel % 32 == 0```. Int8 kernels run faster than fp16 kernels for the following shapes (```C``` is the number of input channels, ```K``` the number of output channels):
```
C == 32 && K == 64
C == 64 && K == 32
C >= 64 && K >= 64
```
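
The constraints above can be checked up front when choosing layer widths. The helper below is a minimal sketch: both function names are hypothetical (they are not part of spconv), and ```C``` / ```K``` are assumed to be the input / output channel counts.

```python
def int8_supported(c: int, k: int) -> bool:
    """Int8 kernels require both channel counts to be multiples of 32."""
    return c % 32 == 0 and k % 32 == 0


def int8_faster_than_fp16(c: int, k: int) -> bool:
    """True for the shapes this guide lists as faster in int8 than in fp16."""
    if not int8_supported(c, k):
        return False
    return (c == 32 and k == 64) or (c == 64 and k == 32) or (c >= 64 and k >= 64)


# Print (C, K, supported, expected to be faster) for a few candidate shapes.
for c, k in [(32, 32), (32, 64), (64, 64), (48, 64), (128, 96)]:
    print(c, k, int8_supported(c, k), int8_faster_than_fp16(c, k))
```

A layer such as ```C == 32 && K == 32``` is supported but not in the "faster" list, so keeping it in fp16 may be the better choice.
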
spconv currently doesn't support int8 pooling operations.
### Prepare model (Common)
We need to modify the model so that it can be symbolically traced by ```torch.fx```. Here are some tips and requirements:
* Only ```forward``` and its body can be traced.
* Conditional statements such as ```if``` and ```assert``` can't depend on ```forward``` arguments. Use an environment variable or a global variable to remove asserts during tracing.
* Traced functions can't take dynamic arguments (```*args``` and ```**kwargs```). Change them to static arguments or make the functions non-traceable.
* Non-traceable code can be hidden from the tracer inside top-level functions and ```torch.nn.Module```s. Put all non-traceable code into stateless top-level functions or Modules:
    1. Write a top-level function (a function declared in global scope), then use ```torch.fx.wrap``` to make sure it's non-traceable.
    2. Write a Module that contains all the non-traceable code.
* After tracing, all attributes, methods and static functions are lost, except those of non-traceable modules. If you still want to use a method of a traced module, refactor it into a non-traceable child module or make it static.
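
The Module route (item 2) and the global-flag tip can be sketched with plain ```torch.fx``` as follows. This is only an illustration under stated assumptions: ```SKIP_ASSERTS```, ```SignFlip```, ```Net``` and ```LeafTracer``` are hypothetical names, and the leaf-module hook of ```torch.fx.Tracer``` is used here directly; how spconv's own helpers mark modules as non-traceable is not shown.

```python
import torch
import torch.fx

SKIP_ASSERTS = False  # hypothetical global switch, set to True while tracing


class SignFlip(torch.nn.Module):
    """Stateless module that holds data-dependent control flow."""

    def forward(self, x):
        if x.sum() > 0:  # depends on input values: not symbolically traceable
            return x
        return -x


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.flip = SignFlip()

    def forward(self, x):
        if not SKIP_ASSERTS:
            assert x.shape[-1] == 4  # would raise on a Proxy during tracing
        return self.flip(self.linear(x))


class LeafTracer(torch.fx.Tracer):
    """Keeps SignFlip as one call_module node instead of tracing into it."""

    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, SignFlip) or super().is_leaf_module(m, module_qualified_name)


model = Net()
SKIP_ASSERTS = True  # remove asserts only for the duration of tracing
graph = LeafTracer().trace(model)
SKIP_ASSERTS = False
traced = torch.fx.GraphModule(model, graph)
print(traced.graph)
```

The printed graph should contain a single ```call_module``` node for ```flip```, so the data-dependent branch still runs at inference time but is never seen by the tracer.
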
### Prepare model (Spconv)
Spconv int8 supports subm residual fusion:
```mermaid
graph TD;
X-->Add;
A-->SubMConv;
SubMConv-->BatchNorm;
BatchNorm-->Add;
Add-->ReLU
```
to
```mermaid
graph TD;
X-->SubMConvAddReLU;
A-->SubMConvAddReLU
```
Due to limitations of ```torch.fx```, this fusion requires that your residual code contain no spconv-specific operations such as ```replace_feature```.

A residual module that uses ```out.replace_feature``` and ```out.features``` can't be fused: these operations are recorded as standalone nodes in the graph, so it's hard to recognize and fuse them.

After testing, we need to convert the torch model to TensorRT. Please see [this doc](TENSORRT_INT8_GUIDE.md) for more details.

Since all int8 kernels are compiled at runtime in the spconv Python package, you can set the environment variable ```SPCONV_INT8_DEBUG=1``` to remove most of the candidate int8 kernels and reduce compile time.
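
The variable can be exported in the shell or set from Python. The snippet below is a small sketch; setting it before importing spconv is the cautious choice, since exactly when spconv reads the variable is an assumption here.

```python
import os

# Set the flag before spconv is imported, so that it is visible no matter
# when the package reads it (assumption: it is read at or after import).
os.environ["SPCONV_INT8_DEBUG"] = "1"

# import spconv.pytorch as spconv  # import only after the variable is set
print(os.environ["SPCONV_INT8_DEBUG"])
```
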
If you get an error saying that some op doesn't support the CUDA backend, just disable quantization for that op in ```qconfig_mapping```.
### Quantization Aware Training (QAT)
See [example](../example/mnist/mnist_qat.py) for a runnable MNIST example.
To perform QAT in PyTorch, we first need to trace the model via ```torch.fx``` and insert observers and fake-quantize nodes into the model.
Due to limitations of TensorRT, the following requirements must be satisfied:
1. Pad all inputs to a static shape.
2. Use a tensor to store the current number of voxels, copy it to the CPU, and slice all inputs to their real shape during inference (enqueue).
3. ```supportsFormatCombination``` must allow exactly one combination, i.e. we must fix the dtype of this layer during network build. For example, if we want to use fp16, this function must accept fp16 and reject all other dtypes, so that TensorRT doesn't perform dtype/format selection during engine build.
4. The number of dimensions of an int8 tensor for a plugin must be greater than or equal to 3 (tested in TensorRT 8.4).
5. TensorRT version >= 8.4 is required; TensorRT 8.0 doesn't support int8 plugins.
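
Requirements 1 and 2 amount to a pad-then-slice round trip. The sketch below shows the idea with plain Python lists standing in for tensors; ```pad_to_static``` and ```slice_to_real``` are hypothetical helper names, not spconv or TensorRT APIs.

```python
from typing import List, Tuple

Rows = List[List[float]]


def pad_to_static(rows: Rows, max_rows: int) -> Tuple[Rows, int]:
    """Pad a [num_voxels, channels] feature list with zero rows up to max_rows.

    Returns the padded rows together with the real number of voxels.
    """
    num_valid = len(rows)
    if num_valid > max_rows:
        raise ValueError(f"{num_valid} voxels exceed static capacity {max_rows}")
    channels = len(rows[0]) if rows else 0
    padding = [[0.0] * channels for _ in range(max_rows - num_valid)]
    return [list(r) for r in rows] + padding, num_valid


def slice_to_real(padded: Rows, num_valid: int) -> Rows:
    """Recover the real rows from a padded buffer (done at enqueue time)."""
    return padded[:num_valid]


features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
padded, num_valid = pad_to_static(features, max_rows=5)
print(len(padded), num_valid)            # static size and real size
print(slice_to_real(padded, num_valid))  # original features again
```
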
There are two int8 modes in TensorRT: implicit and explicit.

For implicit mode, we can use the TensorRT int8 calibrator to calculate scales and use them in the plugin. This isn't tested and isn't covered here.

For explicit mode, we insert QDQ (quantize/dequantize) nodes into the network; TensorRT fuses the QDQ nodes and converts layers to quantized layers based on them. See [this doc](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks).
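
As a reminder of what a single Q/DQ pair computes, here is a minimal arithmetic sketch. It assumes symmetric per-tensor int8 quantization with a given scale; the function names are hypothetical and real frameworks operate on whole tensors rather than scalars.

```python
def quantize(x: float, scale: float, zero_point: int = 0) -> int:
    """Q node: scale, round to nearest integer, clamp to the int8 range."""
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))


def dequantize(q: int, scale: float, zero_point: int = 0) -> float:
    """DQ node: map the int8 value back to a float."""
    return (q - zero_point) * scale


scale = 0.5
for x in [0.3, 1.0, -2.2, 100.0]:
    q = quantize(x, scale)
    print(x, q, dequantize(q, scale))
```

With ```scale = 0.5```, the value ```100.0``` saturates at ```127``` and comes back as ```63.5```, which is the kind of clipping error that QAT lets the network adapt to.
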
There is an important drawback in TensorRT int8: TensorRT won't fuse QDQ for custom int8 plugins. So we must fuse QDQ ourselves (in PyTorch), and **keep QDQ** in regular layers such as linear and convolution.

PyTorch adds QDQ in ```convert_fx``` and ```convert_to_reference_fx```.

```convert_to_reference_fx```: inserts QDQ and converts fused modules to reference modules, but it **doesn't** fuse any QDQ in your network. If we don't want to write the fusion code manually, we can't use this function.

```convert_fx```: inserts QDQ and converts fused modules to quantized modules for the native (CPU) backend. This function fuses **ALL** QDQs in your network, whereas TensorRT explicit quantization requires us to keep QDQ for regular layers.

Currently we implement this via PyTorch ```convert_fx``` plus some hacks:
```Python
import torch.ao.nn.intrinsic as nni
import torch.nn.quantized._reference as nnqr
from torch.ao.quantization.fx._lower_to_native_backend import \
```
If your network contains convolutions, you can do the same thing for conv layers. This isn't covered in ```spconvq.prepare_spconv_torch_inference```.
## Steps
### Record number of voxels for each layer
There is an argument in ```SparseConvolution``` layers: ```record_voxel_count```. If you enable it, the max number of voxels is recorded in a registered buffer during inference. Turn it on and run inference over the whole training dataset.
After inference, we know the max number of voxels of each spconv layer, which is required by the TensorRT plugin.
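
Conceptually, the recording is a running maximum per layer. The sketch below illustrates that idea in plain Python only; ```MaxVoxelCounter``` is a hypothetical class and does not reflect how spconv stores the value internally.

```python
class MaxVoxelCounter:
    """Keeps the largest voxel count seen so far for one layer."""

    def __init__(self) -> None:
        self.max_voxels = 0

    def update(self, num_voxels: int) -> None:
        self.max_voxels = max(self.max_voxels, num_voxels)


counter = MaxVoxelCounter()
for num_voxels in [1200, 3400, 2100]:  # made-up per-batch voxel counts
    counter.update(num_voxels)
print(counter.max_voxels)  # 3400
```
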
### Write ```torch.fx```-based torch->trt conversion
Once the PTQ/QAT model is ready, we can use [```torch.fx.Interpreter```](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter) to transform the traced PyTorch model to TensorRT.
See [example](../example/mnist/mnist_net_transform.py).
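
A converter built on ```torch.fx.Interpreter``` typically overrides the per-node hooks and emits one backend layer per node. The sketch below only records what it visits instead of building a TensorRT network; ```Net``` and ```LayerLogger``` are hypothetical names, and all TensorRT calls are deliberately left out.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


class LayerLogger(torch.fx.Interpreter):
    """Runs the traced graph node by node and records every visited op."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.visited = []

    def call_module(self, target, args, kwargs):
        submod = self.fetch_attr(target)
        # A real converter would add the matching TensorRT layer here.
        self.visited.append((target, type(submod).__name__))
        return super().call_module(target, args, kwargs)

    def call_function(self, target, args, kwargs):
        self.visited.append(("function", getattr(target, "__name__", str(target))))
        return super().call_function(target, args, kwargs)


gm = torch.fx.symbolic_trace(Net())
interp = LayerLogger(gm)
out = interp.run(torch.randn(3, 4))
print(interp.visited)
```

Because the interpreter still executes every node with real tensors, the output shapes of each layer are available while the target network is being built.
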
Prepare step: fuse your model so that all patterns such as conv-bn-relu are fused into modules in ```torch.ao.quantization.intrinsic``` / ```spconv.pytorch.quantization.intrinsic```.