Commit 004effbd authored by yan.yan's avatar yan.yan
Browse files

add some example

parent 2309ebe5
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to write a custom fx2trt-like tool that converts
a PyTorch model to TensorRT.
"""
from __future__ import print_function
import argparse
import contextlib
import copy
from typing import Dict, Optional
import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
import torch.nn as nn
import torch.optim as optim
from torch.fx import Tracer
import tensorrt as trt
from spconv.pytorch.quantization.interpreter import NetworkInterpreter, register_node_handler, register_method_handler
from spconv.pytorch.cppcore import torch_tensor_to_tv
import numpy as np
import spconv.constants as spconvc
import torch.nn.functional as F
def _simple_repr(x):
    """Return a short ``Tensor[<shape>|<dtype>]`` description of ``x``."""
    shape, dtype = x.shape, x.dtype
    return f"Tensor[{shape}|{dtype}]"
# Install the compact shape/dtype repr on trt.ITensor so tensors embedded in
# formatted strings are readable (presumably for the interpreter's verbose
# log lines, which format handler arguments and results).
trt.ITensor.__repr__ = _simple_repr
class NetDense(nn.Module):
    """Small dense CNN used as the conversion example.

    Two 3x3 convolutions, a 2x2 stride-2 convolution that takes the place of
    a pooling layer, then two fully connected layers producing 10 outputs.
    ``fc1`` expects 9216 (= 64 * 12 * 12) features, which is what a
    1x28x28 input yields after the three convolutions.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.conv_pool = nn.Conv2d(64, 64, 2, 2)

    def forward(self, x):
        """Run the network; log-softmax is applied only in training mode."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        # Conv2d.forward takes only the input tensor; kernel size and stride
        # are already fixed in the constructor, so no extra argument here.
        x = self.conv_pool(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        if self.training:
            x = F.log_softmax(x, dim=1)
        return x
def _activation(net, x, act_type, alpha=None, beta=None, name=None):
    """Add an activation layer to ``net`` and return its output tensor.

    Args:
        net: network object providing ``add_activation``.
        x: input tensor of the activation.
        act_type: activation type passed straight to ``add_activation``.
        alpha: optional value assigned to the layer's ``alpha`` attribute.
        beta: optional value assigned to the layer's ``beta`` attribute.
        name: optional name given to both the output tensor and the layer.

    Returns:
        Output 0 of the newly added layer.
    """
    act_layer = net.add_activation(x, act_type)
    # Only override the coefficients the caller actually supplied.
    for attr, value in (("alpha", alpha), ("beta", beta)):
        if value is not None:
            setattr(act_layer, attr, value)
    result = act_layer.get_output(0)
    if name is None:
        return result
    result.name = name
    act_layer.name = name
    return result
def _trt_reshape(net, inp, shape, name):
    """Reshape ``inp`` to ``shape`` with a shuffle layer and return the result.

    Both the layer and its output tensor are named ``name``.
    """
    shuffle = net.add_shuffle(inp)
    shuffle.reshape_dims = shape
    shuffle.name = name
    reshaped = shuffle.get_output(0)
    reshaped.name = name
    return reshaped
# add module handler
@register_node_handler(nn.Conv2d)
def _conv2d(net, target: nn.Conv2d, args, kwargs, name: str):
    """Convert an ``nn.Conv2d`` module call into a TensorRT convolution layer.

    Args:
        net: network the layer is added to.
        target: the ``nn.Conv2d`` module being converted; weights, bias,
            stride, padding, dilation and groups are read from it.
        args: positional inputs of the call; only ``args[0]`` is used.
        kwargs: keyword inputs of the call (unused).
        name: name assigned to the new layer and its output tensor.

    Returns:
        Output 0 of the convolution layer.

    Raises:
        NotImplementedError: if the module uses ``padding='same'``, which
            cannot be expressed as a plain per-dimension padding tuple here.
    """
    x = args[0]
    weight = target.weight.detach().cpu().numpy()
    # Weight layout is (out_channels, in_channels / groups, *kernel_size).
    out_channels, _, *ksize = weight.shape

    padding = target.padding
    if isinstance(padding, str):
        # tuple("same") would silently become a tuple of characters.
        if padding != "valid":
            raise NotImplementedError(
                f"Conv2d padding={padding!r} is not supported by this handler")
        padding = (0,) * len(ksize)

    kernel = trt.Weights(weight)
    if target.bias is None:
        # Empty weights mean "no bias".
        bias = trt.Weights()
    else:
        bias = trt.Weights(target.bias.detach().cpu().numpy())

    layer = net.add_convolution_nd(x, out_channels, tuple(ksize), kernel, bias)
    layer.stride_nd = tuple(target.stride)
    layer.padding_nd = tuple(padding)
    layer.dilation_nd = tuple(target.dilation)
    layer.num_groups = target.groups
    output = layer.get_output(0)
    output.name = name
    layer.name = name
    return output
@register_node_handler(F.relu)
def _relu(net, target, args, kwargs, name: str):
    """Convert an ``F.relu`` call into a RELU activation layer.

    ``target`` is the function this handler is registered for (``F.relu``),
    not a module, so it carries no type annotation. Only ``args[0]`` is used.
    """
    return _activation(net, args[0], trt.ActivationType.RELU, name=name)
@register_node_handler(nn.Dropout)
@register_node_handler(nn.Dropout1d)
@register_node_handler(nn.Dropout2d)
@register_node_handler(nn.Dropout3d)
def _identity_single(net, target, args, kwargs, name: str):
    """Return the first positional input unchanged; no layer is added.

    Registered for the dropout module types, which this converter treats
    as identity.
    """
    passthrough = args[0]
    return passthrough
@register_node_handler(torch.flatten)
def _flatten(net, target, args, kwargs, name: str):
    """Convert a ``torch.flatten`` call into a reshape.

    ``start_dim`` and ``end_dim`` are accepted positionally or by keyword and
    default to ``0`` and ``-1``, matching ``torch.flatten``. The new shape is
    computed from the static ``x.shape``.

    Args:
        net: network the reshape layer is added to.
        target: the traced function (unused).
        args: positional call inputs ``(input[, start_dim[, end_dim]])``.
        kwargs: keyword call inputs.
        name: name assigned to the reshape layer and its output.

    Returns:
        The reshaped tensor.

    Raises:
        ValueError: if the normalised dims do not satisfy
            ``0 <= start_dim <= end_dim < rank``.
    """
    x = args[0] if args else kwargs["input"]
    start_dim = args[1] if len(args) > 1 else kwargs.get("start_dim", 0)
    end_dim = args[2] if len(args) > 2 else kwargs.get("end_dim", -1)

    dims = tuple(x.shape)
    rank = len(dims)
    if rank == 0:
        # torch.flatten turns a 0-d tensor into a 1-element 1-d tensor.
        return _trt_reshape(net, x, [1], name)

    # Normalise negative indices the same way Python indexing does.
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    if not 0 <= start <= end < rank:
        raise ValueError(
            f"invalid flatten dims start_dim={start_dim}, end_dim={end_dim} "
            f"for shape {dims}")

    merged = int(np.prod(dims[start:end + 1]))
    new_shape = [*dims[:start], merged, *dims[end + 1:]]
    return _trt_reshape(net, x, new_shape, name)
def _dot(net, x, y, transpose_x=False, transpose_y=False, name=None):
    """Add a matrix-multiply layer for ``x`` and ``y`` and return its output.

    Args:
        net: network object providing ``add_matrix_multiply``.
        x: left operand.
        y: right operand.
        transpose_x: use the TRANSPOSE matrix operation for ``x``.
        transpose_y: use the TRANSPOSE matrix operation for ``y``.
        name: required name for the layer and its output; despite the
            ``None`` default, ``None`` trips the assertion below.

    Returns:
        Output 0 of the matrix-multiply layer.
    """
    op_x = (trt.MatrixOperation.TRANSPOSE
            if transpose_x else trt.MatrixOperation.NONE)
    op_y = (trt.MatrixOperation.TRANSPOSE
            if transpose_y else trt.MatrixOperation.NONE)
    matmul = net.add_matrix_multiply(x, op_x, y, op_y)
    product = matmul.get_output(0)
    assert name is not None
    product.name = name
    matmul.name = name
    return product
def _constant(net, array, name):
    """Add ``array`` to ``net`` as a constant layer and return its output.

    The data is converted with ``np.array`` and handed over flattened,
    while the layer is declared with the array's original shape.
    """
    values = np.array(array)
    flat_weights = trt.Weights(values.reshape(-1))
    const_layer = net.add_constant(values.shape, flat_weights)
    tensor = const_layer.get_output(0)
    const_layer.name = name
    tensor.name = name
    return tensor
@register_node_handler(nn.Linear)
def _linear(net, target: nn.Linear, args, kwargs, name: str):
    """Convert an ``nn.Linear`` module call into matmul (+ add) layers.

    The weight becomes a constant named ``<name>/weight`` and is multiplied
    with ``args[0]`` using the transposed weight. When the module has a bias
    it is added through an elementwise SUM layer named ``<name>/add``.

    Returns:
        The matmul output, or the output of the add layer if a bias exists.
    """
    x = args[0]
    bias = None if target.bias is None else target.bias.detach().cpu().numpy()
    weight = target.weight.detach().cpu().numpy()

    weight_const = _constant(net, weight, name + "/weight")
    res = _dot(net, x, weight_const, transpose_y=True, name=name)
    if bias is None:
        return res

    # Bias is reshaped to (1, N); assumes a 2-D matmul result -- TODO confirm.
    bias_const = _constant(net, bias.reshape(1, -1), name + "/bias")
    add_layer = net.add_elementwise(res, bias_const,
                                    trt.ElementWiseOperation.SUM)
    summed = add_layer.get_output(0)
    add_name = name + "/add"
    summed.name = add_name
    add_layer.name = add_name
    return summed
def main():
    """Trace ``NetDense`` with torch.fx, convert it to TensorRT and build an engine.

    Raises:
        RuntimeError: if TensorRT fails to build the serialized network or
            to deserialize the resulting plan.
    """
    model = NetDense().eval()
    tracer = Tracer()
    graph = tracer.trace(model)
    gm = torch.fx.GraphModule(tracer.root, graph)

    logger = trt.Logger(trt.Logger.WARNING)
    # NOTE: the original author left a pycuda context import
    # (pycuda.autoprimaryctx / pycuda.autoinit) disabled at this point.

    # Same numeric value as the previous ``create_network(True)``, but spelled
    # as the flag it actually is.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    workspace_bytes = 1 << 30

    with trt.Runtime(logger) as rt:
        with trt.Builder(logger) as builder:
            with builder.create_network(explicit_batch) as network:
                config = builder.create_builder_config()
                # Newer TensorRT releases replace ``max_workspace_size`` with
                # memory-pool limits; fall back for older ones.
                if hasattr(config, "set_memory_pool_limit"):
                    config.set_memory_pool_limit(
                        trt.MemoryPoolType.WORKSPACE, workspace_bytes)
                else:
                    config.max_workspace_size = workspace_bytes
                input_tensor = network.add_input(
                    name="inp", dtype=trt.float32, shape=[1, 1, 28, 28])
                interp = NetworkInterpreter(
                    network, gm, [input_tensor], verbose=True)
                # get converted outputs from interp
                outputs = interp.run()
                network.mark_output(tensor=outputs[0])
                plan = builder.build_serialized_network(network, config)
                if plan is None:
                    raise RuntimeError(
                        "TensorRT failed to build the serialized network")
                engine = rt.deserialize_cuda_engine(plan)
                if engine is None:
                    raise RuntimeError(
                        "TensorRT failed to deserialize the built engine")


if __name__ == '__main__':
    main()
from typing import Any, Dict, List, Optional, Set, Type
import torch
import torch.fx
# Maps a handler key (module type, function, or ``(tensor_class, method)``
# tuple) to the converter registered for it.
REGISTERED_NODE_HANDLERS: Dict[Any, Any] = {}


def register_node_handler(*names):
    """Decorator factory: register the handler under every key in ``names``.

    The registry stores the handler that was passed in; the decorated name is
    bound to a thin wrapper that forwards all arguments to it.
    """

    def wrap_func(handler):
        for key in names:
            REGISTERED_NODE_HANDLERS[key] = handler

        def new_handler(*args, **kwargs):
            return handler(*args, **kwargs)

        return new_handler

    return wrap_func


def register_method_handler(name: str, tensor_classes):
    """Decorator factory: register a handler for method ``name``.

    ``tensor_classes`` may be a single class or a list/tuple of classes; the
    handler is stored under one ``(class, name)`` key per class.
    """
    if isinstance(tensor_classes, (list, tuple)):
        classes = tensor_classes
    else:
        classes = [tensor_classes]

    def wrap_func(handler):
        for cls in classes:
            REGISTERED_NODE_HANDLERS[(cls, name)] = handler

        def new_handler(*args, **kwargs):
            return handler(*args, **kwargs)

        return new_handler

    return wrap_func


def get_node_handler(name):
    """Return the handler registered under ``name``.

    Fails an assertion listing every available key when ``name`` is unknown.
    """
    known = list(REGISTERED_NODE_HANDLERS)
    msg = f"missing handler {name!s}, available handlers: {known}"
    assert name in REGISTERED_NODE_HANDLERS, msg
    return REGISTERED_NODE_HANDLERS[name]
class NetworkInterpreter(torch.fx.Interpreter):
    """``torch.fx.Interpreter`` that hands every node to a registered handler.

    Module, function and method nodes are looked up with
    ``get_node_handler`` and the handler is called as
    ``handler(network_ctx, target, args, kwargs, node_name)``. ``run``
    returns the positional arguments of the graph's output node.
    """

    def __init__(self,
                 network_ctx,
                 module: torch.fx.GraphModule,
                 inputs: List[Any],
                 verbose: bool = False):
        super().__init__(module)
        self.network_ctx = network_ctx
        self._inputs = inputs
        self._outputs = None
        # Name of the node currently being executed; set in ``run_node``.
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        self._verbose = verbose

    def run(self):
        """Interpret the graph with the stored inputs and return its outputs."""
        super().run(*self._inputs)
        assert self._outputs is not None
        return self._outputs

    def run_node(self, n):
        """Record the node's name, then execute it as usual."""
        self._cur_node_name = str(n)
        return super().run_node(n)

    def _dispatch_to_handler(self, msg, key_of, subject, args, kwargs):
        """Look up and call a handler, logging ``msg`` when verbose.

        ``key_of`` is called inside the ``try`` so that a failure while
        building the key is logged like any other conversion error.
        """
        try:
            converter = get_node_handler(key_of())
            res = converter(self.network_ctx, subject, args, kwargs,
                            self._cur_node_name)
            msg += f"{res}"
            if self._verbose:
                print(msg)
            return res
        except Exception as e:
            if self._verbose:
                print(msg)
            raise e

    def call_module(self, target, args, kwargs):
        """Convert a module call; the handler key is the submodule's type."""
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        # Prefer ``_base_class_origin`` when the submodule provides it.
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        short_name = submod_type.__qualname__.split(".")[-1]
        msg = f"[Module.{short_name}]{target}({args}|{kwargs}) => "
        return self._dispatch_to_handler(msg, lambda: submod_type, submod,
                                         args, kwargs)

    def call_function(self, target, args, kwargs):
        """Convert a function call; the handler key is the function itself."""
        msg = f"[Func]{target}({args}|{kwargs}) => "
        return self._dispatch_to_handler(msg, lambda: target, target, args,
                                         kwargs)

    def call_method(self, target, args, kwargs):
        """Convert a method call; the key is ``(type(args[0]), target)``."""
        msg = f"[Method]{target}({args}|{kwargs}) => "
        assert isinstance(target, str)
        return self._dispatch_to_handler(
            msg, lambda: (type(args[0]), target), target, args, kwargs)

    def output(self, target, args, kwargs):
        """Store the output node's positional arguments for ``run``."""
        self._outputs = args
from typing import Any, Dict, List, Optional, Set, Type
import torch
import torch.fx
# Maps a handler key (module type, function, or ``(tensor_class, method)``
# tuple) to the converter registered for it.
REGISTERED_NODE_HANDLERS: Dict[Any, Any] = {}


def _forwarding_wrapper(handler):
    """Return a wrapper that forwards any call signature to ``handler``."""
    import functools  # local import keeps this block self-contained

    @functools.wraps(handler)
    def new_handler(*args, **kwargs):
        return handler(*args, **kwargs)

    return new_handler


def register_node_handler(*names):
    """Decorator factory: register the handler under every key in ``names``.

    The decorated name is bound to a signature-agnostic wrapper. A wrapper
    with a fixed parameter list would break stacked decorators, because the
    outer decorator registers the inner wrapper, and handlers are invoked
    with five positional arguments.
    """

    def wrap_func(handler):
        for key in names:
            REGISTERED_NODE_HANDLERS[key] = handler
        return _forwarding_wrapper(handler)

    return wrap_func


def register_method_handler(name: str, tensor_classes):
    """Decorator factory: register a handler for method ``name``.

    ``tensor_classes`` may be a single class or a list/tuple of classes; the
    handler is stored under one ``(class, name)`` key per class.
    """
    if not isinstance(tensor_classes, (list, tuple)):
        tensor_classes = [tensor_classes]

    def wrap_func(handler):
        for tcls in tensor_classes:
            REGISTERED_NODE_HANDLERS[(tcls, name)] = handler
        return _forwarding_wrapper(handler)

    return wrap_func


def get_node_handler(name):
    """Return the handler registered under ``name``.

    Raises:
        AssertionError: if no handler is registered for ``name``. It is
            raised explicitly so the diagnostic survives ``python -O``, and
            the message is only built on failure.
    """
    if name not in REGISTERED_NODE_HANDLERS:
        msg = "missing handler " + str(name)
        msg += ", available handlers: {}".format(
            list(REGISTERED_NODE_HANDLERS.keys()))
        raise AssertionError(msg)
    return REGISTERED_NODE_HANDLERS[name]
class NetworkInterpreter(torch.fx.Interpreter):
    """``torch.fx.Interpreter`` that hands every node to a registered handler.

    Module, function and method nodes are looked up with
    ``get_node_handler`` and the handler is called as
    ``handler(network_ctx, target, args, kwargs, node_name)``. ``run``
    returns the positional arguments of the graph's output node.
    """

    def __init__(self,
                 network_ctx,
                 module: torch.fx.GraphModule,
                 inputs: List[Any],
                 verbose: bool = False):
        super().__init__(module)
        self.network_ctx = network_ctx
        self._inputs = inputs
        self._outputs = None
        # Name of the node currently being executed; set in ``run_node``.
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        self._verbose = verbose

    def run(self):
        """Interpret the graph with the stored inputs and return its outputs."""
        super().run(*self._inputs)
        assert self._outputs is not None
        return self._outputs

    def run_node(self, n):
        """Record the node's name, then execute it as usual."""
        self._cur_node_name = str(n)
        return super().run_node(n)

    def _dispatch_to_handler(self, msg, key_of, subject, args, kwargs):
        """Look up and call a handler, logging ``msg`` when verbose.

        ``key_of`` is called inside the ``try`` so that a failure while
        building the key is logged like any other conversion error.
        """
        try:
            converter = get_node_handler(key_of())
            res = converter(self.network_ctx, subject, args, kwargs,
                            self._cur_node_name)
            msg += f"{res}"
            if self._verbose:
                print(msg)
            return res
        except Exception as e:
            if self._verbose:
                print(msg)
            raise e

    def call_module(self, target, args, kwargs):
        """Convert a module call; the handler key is the submodule's type."""
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        # Prefer ``_base_class_origin`` when the submodule provides it.
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        short_name = submod_type.__qualname__.split(".")[-1]
        msg = f"[Module.{short_name}]{target}({args}|{kwargs}) => "
        return self._dispatch_to_handler(msg, lambda: submod_type, submod,
                                         args, kwargs)

    def call_function(self, target, args, kwargs):
        """Convert a function call; the handler key is the function itself."""
        msg = f"[Func]{target}({args}|{kwargs}) => "
        return self._dispatch_to_handler(msg, lambda: target, target, args,
                                         kwargs)

    def call_method(self, target, args, kwargs):
        """Convert a method call; the key is ``(type(args[0]), target)``."""
        msg = f"[Method]{target}({args}|{kwargs}) => "
        assert isinstance(target, str)
        return self._dispatch_to_handler(
            msg, lambda: (type(args[0]), target), target, args, kwargs)

    def output(self, target, args, kwargs):
        """Store the output node's positional arguments for ``run``."""
        self._outputs = args
from ..interpreter import *
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment