Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
004effbd
Commit
004effbd
authored
Feb 02, 2023
by
yan.yan
Browse files
add some example
parent
2309ebe5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
353 additions
and
129 deletions
+353
-129
example/mnist/custom_fx2trt.py
example/mnist/custom_fx2trt.py
+223
-0
spconv/pytorch/interpreter.py
spconv/pytorch/interpreter.py
+129
-0
spconv/pytorch/quantization/interpreter.py
spconv/pytorch/quantization/interpreter.py
+1
-129
No files found.
example/mnist/custom_fx2trt.py
0 → 100644
View file @
004effbd
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to write custom fx2trt like tool to convert
pytorch model to tensorrt.
"""
from
__future__
import
print_function
import
argparse
import
contextlib
import
copy
from
typing
import
Dict
,
Optional
import
torch
import
torch.ao.quantization
import
torch.ao.quantization.quantize_fx
as
qfx
import
torch.cuda.amp
import
torch.fx
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.fx
import
Tracer
import
tensorrt
as
trt
from
spconv.pytorch.quantization.interpreter
import
NetworkInterpreter
,
register_node_handler
,
register_method_handler
from
spconv.pytorch.cppcore
import
torch_tensor_to_tv
import
numpy
as
np
import
spconv.constants
as
spconvc
import
torch.nn.functional
as
F
def
_simple_repr
(
x
):
return
f
"Tensor[
{
x
.
shape
}
|
{
x
.
dtype
}
]"
# add verbose for ITensor
trt
.
ITensor
.
__repr__
=
_simple_repr
class
NetDense
(
nn
.
Module
):
def
__init__
(
self
):
super
(
NetDense
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
3
,
1
)
self
.
conv2
=
nn
.
Conv2d
(
32
,
64
,
3
,
1
)
self
.
dropout1
=
nn
.
Dropout
(
0.25
)
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
self
.
fc1
=
nn
.
Linear
(
9216
,
128
)
self
.
fc2
=
nn
.
Linear
(
128
,
10
)
self
.
conv_pool
=
nn
.
Conv2d
(
64
,
64
,
2
,
2
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv_pool
(
x
,
2
)
x
=
self
.
dropout1
(
x
)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
fc2
(
x
)
if
self
.
training
:
x
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
x
def
_activation
(
net
,
x
,
act_type
,
alpha
=
None
,
beta
=
None
,
name
=
None
):
layer
=
net
.
add_activation
(
x
,
act_type
)
if
alpha
is
not
None
:
layer
.
alpha
=
alpha
if
beta
is
not
None
:
layer
.
beta
=
beta
output
=
layer
.
get_output
(
0
)
if
name
is
not
None
:
output
.
name
=
name
layer
.
name
=
name
return
output
def
_trt_reshape
(
net
,
inp
,
shape
,
name
):
layer
=
net
.
add_shuffle
(
inp
)
layer
.
reshape_dims
=
shape
output
=
layer
.
get_output
(
0
)
layer
.
name
=
name
output
.
name
=
name
return
output
# add module handler
@
register_node_handler
(
nn
.
Conv2d
)
def
_conv2d
(
net
,
target
:
nn
.
Conv2d
,
args
,
kwargs
,
name
:
str
):
x
=
args
[
0
]
bias
=
target
.
bias
if
target
.
bias
is
None
:
bias
=
None
else
:
bias
=
target
.
bias
.
detach
().
cpu
().
numpy
()
weight
=
target
.
weight
.
detach
().
cpu
().
numpy
()
O
,
I_groups
,
*
ksize
=
weight
.
shape
I
=
I_groups
*
target
.
groups
stride
=
target
.
stride
padding
=
target
.
padding
dilation
=
target
.
dilation
weight_qdq
=
None
if
not
isinstance
(
weight
,
np
.
ndarray
):
weight_qdq
=
weight
weight
=
trt
.
Weights
()
else
:
weight
=
trt
.
Weights
(
weight
)
if
bias
is
None
:
bias
=
trt
.
Weights
()
else
:
bias
=
trt
.
Weights
(
bias
)
layer
=
net
.
add_convolution_nd
(
x
,
O
,
tuple
(
ksize
),
weight
,
bias
)
if
weight_qdq
is
not
None
:
# in explicit quantization, we need this
layer
.
set_input
(
1
,
weight_qdq
)
layer
.
stride_nd
=
tuple
(
stride
)
layer
.
padding_nd
=
tuple
(
padding
)
layer
.
dilation_nd
=
tuple
(
dilation
)
layer
.
num_groups
=
target
.
groups
output
=
layer
.
get_output
(
0
)
output
.
name
=
name
layer
.
name
=
name
return
output
@
register_node_handler
(
F
.
relu
)
def
_relu
(
net
,
target
:
nn
.
Conv2d
,
args
,
kwargs
,
name
:
str
):
return
_activation
(
net
,
args
[
0
],
trt
.
ActivationType
.
RELU
,
name
=
name
)
@
register_node_handler
(
nn
.
Dropout
)
@
register_node_handler
(
nn
.
Dropout1d
)
@
register_node_handler
(
nn
.
Dropout2d
)
@
register_node_handler
(
nn
.
Dropout3d
)
def
_identity_single
(
net
,
target
,
args
,
kwargs
,
name
:
str
):
return
args
[
0
]
@
register_node_handler
(
torch
.
flatten
)
def
_flatten
(
net
,
target
,
args
,
kwargs
,
name
:
str
):
start_dim
=
args
[
1
]
x
=
args
[
0
]
return
_trt_reshape
(
net
,
x
,
[
*
x
.
shape
[:
start_dim
],
int
(
np
.
prod
(
x
.
shape
[
start_dim
:]))],
name
)
def
_dot
(
net
,
x
,
y
,
transpose_x
=
False
,
transpose_y
=
False
,
name
=
None
):
mode_x
=
trt
.
MatrixOperation
.
NONE
if
transpose_x
:
mode_x
=
trt
.
MatrixOperation
.
TRANSPOSE
mode_y
=
trt
.
MatrixOperation
.
NONE
if
transpose_y
:
mode_y
=
trt
.
MatrixOperation
.
TRANSPOSE
layer
=
net
.
add_matrix_multiply
(
x
,
mode_x
,
y
,
mode_y
)
output
=
layer
.
get_output
(
0
)
assert
name
is
not
None
output
.
name
=
name
layer
.
name
=
name
return
output
def
_constant
(
net
,
array
,
name
):
array
=
np
.
array
(
array
)
layer
=
net
.
add_constant
(
array
.
shape
,
trt
.
Weights
(
array
.
reshape
(
-
1
)))
out
=
layer
.
get_output
(
0
)
layer
.
name
=
name
out
.
name
=
name
return
out
@
register_node_handler
(
nn
.
Linear
)
def
_linear
(
net
,
target
:
nn
.
Linear
,
args
,
kwargs
,
name
:
str
):
x
=
args
[
0
]
bias
=
target
.
bias
if
target
.
bias
is
None
:
bias
=
None
else
:
bias
=
target
.
bias
.
detach
().
cpu
().
numpy
()
weight
=
target
.
weight
.
detach
().
cpu
().
numpy
()
weight_trt
=
_constant
(
net
,
weight
,
name
+
"/weight"
)
res
=
_dot
(
net
,
x
,
weight_trt
,
transpose_y
=
True
,
name
=
name
)
if
bias
is
not
None
:
bias_trt
=
_constant
(
net
,
bias
.
reshape
(
1
,
-
1
),
name
+
"/bias"
)
layer
=
net
.
add_elementwise
(
res
,
bias_trt
,
trt
.
ElementWiseOperation
.
SUM
)
res
=
layer
.
get_output
(
0
)
add_name
=
name
+
"/add"
res
.
name
=
add_name
layer
.
name
=
add_name
return
res
def
main
():
model
=
NetDense
()
model
=
model
.
eval
()
tc
=
Tracer
()
graph_trace
=
tc
.
trace
(
model
)
gm
=
torch
.
fx
.
GraphModule
(
tc
.
root
,
graph_trace
)
import
tensorrt
as
trt
TRT_LOGGER
=
trt
.
Logger
(
trt
.
Logger
.
WARNING
)
# try:
# import pycuda.autoprimaryctx
# except ModuleNotFoundError:
# import pycuda.autoinit
with
trt
.
Runtime
(
TRT_LOGGER
)
as
rt
:
with
trt
.
Builder
(
TRT_LOGGER
)
as
builder
:
with
builder
.
create_network
(
True
)
as
network
:
config
=
builder
.
create_builder_config
()
config
.
max_workspace_size
=
1
<<
30
input_tensor
=
network
.
add_input
(
name
=
"inp"
,
dtype
=
trt
.
float32
,
shape
=
[
1
,
1
,
28
,
28
])
interp
=
NetworkInterpreter
(
network
,
gm
,
[
input_tensor
],
verbose
=
True
)
# get converted outputs from interp
outputs
=
interp
.
run
()
network
.
mark_output
(
tensor
=
outputs
[
0
])
plan
=
builder
.
build_serialized_network
(
network
,
config
)
engine
=
rt
.
deserialize_cuda_engine
(
plan
)
if
__name__
==
'__main__'
:
main
()
spconv/pytorch/interpreter.py
0 → 100644
View file @
004effbd
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
import
torch
import
torch.fx
REGISTERED_NODE_HANDLERS
:
Dict
[
Any
,
Any
]
=
{}
def
register_node_handler
(
*
names
):
def
wrap_func
(
handler
):
global
REGISTERED_NODE_HANDLERS
for
n
in
names
:
REGISTERED_NODE_HANDLERS
[
n
]
=
handler
def
new_handler
(
*
args
,
**
kwargs
):
return
handler
(
*
args
,
**
kwargs
)
return
new_handler
return
wrap_func
def
register_method_handler
(
name
:
str
,
tensor_classes
):
if
not
isinstance
(
tensor_classes
,
(
list
,
tuple
)):
tensor_classes
=
[
tensor_classes
]
def
wrap_func
(
handler
):
global
REGISTERED_NODE_HANDLERS
for
tcls
in
tensor_classes
:
REGISTERED_NODE_HANDLERS
[(
tcls
,
name
)]
=
handler
def
new_handler
(
*
args
,
**
kwargs
):
return
handler
(
*
args
,
**
kwargs
)
return
new_handler
return
wrap_func
def
get_node_handler
(
name
):
global
REGISTERED_NODE_HANDLERS
msg
=
"missing handler "
+
str
(
name
)
msg
+=
", available handlers: {}"
.
format
(
list
(
REGISTERED_NODE_HANDLERS
.
keys
()))
assert
name
in
REGISTERED_NODE_HANDLERS
,
msg
return
REGISTERED_NODE_HANDLERS
[
name
]
class
NetworkInterpreter
(
torch
.
fx
.
Interpreter
):
def
__init__
(
self
,
network_ctx
,
module
:
torch
.
fx
.
GraphModule
,
inputs
:
List
[
Any
],
verbose
:
bool
=
False
):
super
().
__init__
(
module
)
self
.
network_ctx
=
network_ctx
self
.
_inputs
=
inputs
self
.
_outputs
=
None
self
.
_cur_node_name
:
Optional
[
str
]
=
None
self
.
_input_names
:
List
[
str
]
=
[]
self
.
_output_names
:
List
[
str
]
=
[]
self
.
_verbose
=
verbose
def
run
(
self
):
super
().
run
(
*
self
.
_inputs
)
assert
self
.
_outputs
is
not
None
return
self
.
_outputs
def
run_node
(
self
,
n
):
self
.
_cur_node_name
=
str
(
n
)
return
super
().
run_node
(
n
)
def
call_module
(
self
,
target
,
args
,
kwargs
):
assert
isinstance
(
target
,
str
)
submod
=
self
.
fetch_attr
(
target
)
submod_type
=
getattr
(
submod
,
"_base_class_origin"
,
type
(
submod
))
type_str
=
submod_type
.
__qualname__
type_str_parts
=
type_str
.
split
(
"."
)
msg
=
f
"[Module.
{
type_str_parts
[
-
1
]
}
]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
try
:
converter
=
get_node_handler
(
submod_type
)
res
=
converter
(
self
.
network_ctx
,
submod
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
call_function
(
self
,
target
,
args
,
kwargs
):
msg
=
f
"[Func]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
try
:
converter
=
get_node_handler
(
target
)
res
=
converter
(
self
.
network_ctx
,
target
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
call_method
(
self
,
target
,
args
,
kwargs
):
msg
=
f
"[Method]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
assert
isinstance
(
target
,
str
)
try
:
key
=
(
type
(
args
[
0
]),
target
)
converter
=
get_node_handler
(
key
)
res
=
converter
(
self
.
network_ctx
,
target
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
output
(
self
,
target
,
args
,
kwargs
):
self
.
_outputs
=
args
spconv/pytorch/quantization/interpreter.py
View file @
004effbd
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
from
..interpreter
import
*
\ No newline at end of file
import
torch
import
torch.fx
REGISTERED_NODE_HANDLERS
:
Dict
[
Any
,
Any
]
=
{}
def
register_node_handler
(
*
names
):
def
wrap_func
(
handler
):
global
REGISTERED_NODE_HANDLERS
for
n
in
names
:
REGISTERED_NODE_HANDLERS
[
n
]
=
handler
def
new_handler
(
inputs
,
attributes
,
scope
):
return
handler
(
inputs
,
attributes
,
scope
)
return
new_handler
return
wrap_func
def
register_method_handler
(
name
:
str
,
tensor_classes
):
if
not
isinstance
(
tensor_classes
,
(
list
,
tuple
)):
tensor_classes
=
[
tensor_classes
]
def
wrap_func
(
handler
):
global
REGISTERED_NODE_HANDLERS
for
tcls
in
tensor_classes
:
REGISTERED_NODE_HANDLERS
[(
tcls
,
name
)]
=
handler
def
new_handler
(
inputs
,
attributes
,
scope
):
return
handler
(
inputs
,
attributes
,
scope
)
return
new_handler
return
wrap_func
def
get_node_handler
(
name
):
global
REGISTERED_NODE_HANDLERS
msg
=
"missing handler "
+
str
(
name
)
msg
+=
", available handlers: {}"
.
format
(
list
(
REGISTERED_NODE_HANDLERS
.
keys
()))
assert
name
in
REGISTERED_NODE_HANDLERS
,
msg
return
REGISTERED_NODE_HANDLERS
[
name
]
class
NetworkInterpreter
(
torch
.
fx
.
Interpreter
):
def
__init__
(
self
,
network_ctx
,
module
:
torch
.
fx
.
GraphModule
,
inputs
:
List
[
Any
],
verbose
:
bool
=
False
):
super
().
__init__
(
module
)
self
.
network_ctx
=
network_ctx
self
.
_inputs
=
inputs
self
.
_outputs
=
None
self
.
_cur_node_name
:
Optional
[
str
]
=
None
self
.
_input_names
:
List
[
str
]
=
[]
self
.
_output_names
:
List
[
str
]
=
[]
self
.
_verbose
=
verbose
def
run
(
self
):
super
().
run
(
*
self
.
_inputs
)
assert
self
.
_outputs
is
not
None
return
self
.
_outputs
def
run_node
(
self
,
n
):
self
.
_cur_node_name
=
str
(
n
)
return
super
().
run_node
(
n
)
def
call_module
(
self
,
target
,
args
,
kwargs
):
assert
isinstance
(
target
,
str
)
submod
=
self
.
fetch_attr
(
target
)
submod_type
=
getattr
(
submod
,
"_base_class_origin"
,
type
(
submod
))
type_str
=
submod_type
.
__qualname__
type_str_parts
=
type_str
.
split
(
"."
)
msg
=
f
"[Module.
{
type_str_parts
[
-
1
]
}
]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
try
:
converter
=
get_node_handler
(
submod_type
)
res
=
converter
(
self
.
network_ctx
,
submod
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
call_function
(
self
,
target
,
args
,
kwargs
):
msg
=
f
"[Func]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
try
:
converter
=
get_node_handler
(
target
)
res
=
converter
(
self
.
network_ctx
,
target
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
call_method
(
self
,
target
,
args
,
kwargs
):
msg
=
f
"[Method]
{
target
}
(
{
args
}
|
{
kwargs
}
) => "
assert
isinstance
(
target
,
str
)
try
:
key
=
(
type
(
args
[
0
]),
target
)
converter
=
get_node_handler
(
key
)
res
=
converter
(
self
.
network_ctx
,
target
,
args
,
kwargs
,
self
.
_cur_node_name
)
msg
+=
f
"
{
res
}
"
if
self
.
_verbose
:
print
(
msg
)
return
res
except
Exception
as
e
:
if
self
.
_verbose
:
print
(
msg
)
raise
e
def
output
(
self
,
target
,
args
,
kwargs
):
self
.
_outputs
=
args
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment