Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdeploy
Commits
546b4279
Commit
546b4279
authored
Jun 25, 2025
by
limm
Browse files
add csrc and mmdeploy module
parent
502f4fb9
Pipeline
#2810
canceled with stages
Changes
447
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1340 additions
and
0 deletions
+1340
-0
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu
...oy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu
+75
-0
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp
...y/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp
+11
-0
csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt
csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt
+3
-0
csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt
csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt
+49
-0
csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp
csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp
+13
-0
csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp
...backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp
+31
-0
csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp
...t/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp
+94
-0
csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu
...t/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu
+97
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt
...mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt
+18
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp
csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp
+47
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp
...backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp
+313
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h
...y/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h
+38
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp
.../mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp
+70
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h
csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h
+10
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp
...ipt/optimizer/passes/onnx/common_subgraph_elimination.cpp
+138
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h
...cript/optimizer/passes/onnx/common_subgraph_elimination.h
+20
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp
...ps/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp
+119
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h
..._ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h
+14
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp
.../torchscript/optimizer/passes/onnx/fuse_select_assign.cpp
+163
-0
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h
...ps/torchscript/optimizer/passes/onnx/fuse_select_assign.h
+17
-0
No files found.
Too many changes to show.
To preserve performance only
447 of 447+
files are displayed.
Plain diff
Email patch
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <vector>
#include "common_cuda_helper.hpp"
#include "trt_plugin_helper.hpp"
using
mmdeploy
::
TensorDesc
;
// Scatter `update` slices into `output` following ONNX ScatterND semantics.
//
// One thread processes one index tuple: it resolves the tuple to a flat
// element offset in `output` and copies a contiguous run of `copy_len`
// elements out of `update`.
//
// n           : number of index tuples (product of indices shape minus last dim)
// indices     : device buffer of index tuples, laid out row by row
// update      : device buffer holding the slices to be written
// output      : destination tensor (must already hold a copy of the input data)
// tensor_desc : shape/stride description of `output`
// indice_desc : shape/stride description of `indices`
template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices, const T* update, T* output,
                                      TensorDesc tensor_desc, TensorDesc indice_desc) {
  // Width of one index tuple == extent of the last indices dimension.
  const int tuple_len = indice_desc.shape[indice_desc.dim - 1];
  // Elements covered by one scattered slice: stride of the deepest indexed dim.
  const int copy_len = tensor_desc.stride[tuple_len - 1];
  const int* out_stride = &(tensor_desc.stride[0]);
  CUDA_1D_KERNEL_LOOP(tid, n) {
    const int* tuple = indices + tid * tuple_len;
    int dst_offset = 0;
    for (int d = 0; d < tuple_len; ++d) {
      dst_offset += out_stride[d] * tuple[d];
    }
    // Per-thread slice copy; slices of distinct tuples are assumed disjoint
    // (overlapping indices are undefined behavior in ONNX ScatterND anyway).
    memcpy(output + dst_offset, update + tid * copy_len, copy_len * sizeof(T));
  }
}
// Build a row-major TensorDesc (shape + strides) from a dims array of length
// nbDims; all other descriptor entries are zeroed. Shared by the output and
// indices descriptors below (the original filled both inline, duplicated).
static void trt_scatternd_fill_desc(TensorDesc& desc, const int* dims, int nbDims) {
  memset((void*)&desc, 0, sizeof(TensorDesc));
  desc.dim = nbDims;
  desc.shape[nbDims - 1] = dims[nbDims - 1];
  desc.stride[nbDims - 1] = 1;
  for (int i = nbDims - 2; i >= 0; --i) {
    desc.shape[i] = dims[i];
    desc.stride[i] = dims[i + 1] * desc.stride[i + 1];
  }
}

// Asynchronously compute ONNX ScatterND on `stream`:
//   output = copy(data); then output[indices[j]] = update[j] for each tuple j.
//
// data/indices/update/output are device pointers. `dims` (length nbDims)
// describes data/output; `indices_dims` (length indice_nbDims) describes
// indices. nbDims/indice_nbDims must fit the fixed-capacity arrays inside
// TensorDesc. No error checking is performed here; callers are expected to
// synchronize/check on `stream`.
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update,
                                    const int* dims, int nbDims, const int* indices_dims,
                                    int indice_nbDims, T* output, cudaStream_t stream) {
  // fill tensordesc and initial
  TensorDesc tensor_desc;
  trt_scatternd_fill_desc(tensor_desc, dims, nbDims);
  const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];

  TensorDesc indice_desc;
  trt_scatternd_fill_desc(indice_desc, indices_dims, indice_nbDims);

  // output = np.copy(data)
  cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);

  // Number of index tuples = product of all indices dims except the last.
  int num_update_indice = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) {
    num_update_indice *= indice_desc.shape[i];
  }

  // scatter. Guard the empty case: DIVUP(0, THREADS_PER_BLOCK) == 0 and a
  // kernel launch with a zero-sized grid fails with
  // cudaErrorInvalidConfiguration instead of being a no-op.
  if (num_update_indice > 0) {
    const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK);
    onnx_scatternd_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
        num_update_indice, indices, update, output, tensor_desc, indice_desc);
  }
}
// Explicit instantiations: the launcher is only referenced from this
// translation unit's plugin code for float and int tensors.
template void TRTONNXScatterNDKernelLauncher<float>(const float* data, const int* indices,
                                                    const float* update, const int* dims,
                                                    int nbDims, const int* indices_dims,
                                                    int indice_nbDims, float* output,
                                                    cudaStream_t stream);

template void TRTONNXScatterNDKernelLauncher<int>(const int* data, const int* indices,
                                                  const int* update, const int* dims, int nbDims,
                                                  const int* indices_dims, int indice_nbDims,
                                                  int* output, cudaStream_t stream);
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_KERNEL_HPP
#define TRT_SCATTERND_KERNEL_HPP
#include <cuda_runtime.h>

// Launch ONNX ScatterND on `stream`: copy `data` into `output`, then scatter
// the `update` slices at the positions selected by `indices`. All pointers
// are device pointers; `dims`/`indices_dims` describe the data and indices
// tensors respectively. Instantiated for float and int in the .cu file.
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update,
                                    const int* dims, int nbDims, const int* indices_dims,
                                    int indice_nbDims, T* output, cudaStream_t stream);

#endif  // TRT_SCATTERND_KERNEL_HPP
csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt
0 → 100644
View file @
546b4279
# Copyright (c) OpenMMLab. All rights reserved.
# Descend into the torchscript custom-op implementations.
add_subdirectory(ops)
csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt
0 → 100644
View file @
546b4279
# Copyright (c) OpenMMLab. All rights reserved.

# Enable the CUDA language (and pick up .cu sources) only when a cuda device
# target was requested for this build.
if ("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
  project(mmdeploy_torchscript_ops CUDA CXX)
  file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu)
else ()
  project(mmdeploy_torchscript_ops CXX)
  file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp)
endif ()

find_package(Torch REQUIRED)

if (MSVC)
  # workaround to fix building torchscript ops on windows: re-export the torch
  # targets' interface compile options through -Xcompiler so nvcc forwards
  # them to the host compiler instead of choking on them.
  set(_TORCH_TARGET torch_cuda_cu torch_cuda_cpp torch_cpu)
  foreach (_target IN LISTS _TORCH_TARGET)
    if (TARGET ${_target})
      get_property(FIXED_TORCH_CPU_COMPILE_OPTIONS
                   TARGET ${_target}
                   PROPERTY INTERFACE_COMPILE_OPTIONS)
      string(REPLACE ";" " " FIXED_TORCH_CPU_COMPILE_OPTIONS
             "${FIXED_TORCH_CPU_COMPILE_OPTIONS}")
      set_property(TARGET ${_target}
                   PROPERTY INTERFACE_COMPILE_OPTIONS -Xcompiler
                   "${FIXED_TORCH_CPU_COMPILE_OPTIONS}")
    else ()
      message(WARNING "Target ${_target} not found.")
    endif ()
  endforeach ()
endif ()

# Compile all op sources once into an object library, shared by the module
# library below and by any other consumer via mmdeploy_export.
add_library(${PROJECT_NAME}_obj OBJECT "${BACKEND_OPS_SRCS}")
set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1)
target_compile_definitions(${PROJECT_NAME}_obj
                           PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1)
target_include_directories(${PROJECT_NAME}_obj
                           PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../common)
target_include_directories(${PROJECT_NAME}_obj
                           PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common)
if ("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
  target_include_directories(${PROJECT_NAME}_obj
                             PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include)
endif ()
target_link_libraries(${PROJECT_NAME}_obj PRIVATE ${TORCH_LIBRARIES})
mmdeploy_export(${PROJECT_NAME}_obj)

# Build module library. It is used to inference with torchscript
mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME})

# Install the loadable op library next to the python package.
set(_TORCHJIT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_TORCHJIT_OPS_DIR})
csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "torch/script.h"

// Declare the schemas of the custom mmdeploy TorchScript operators. The
// actual CPU/CUDA kernels are registered in their own translation units via
// TORCH_LIBRARY_IMPL(mmdeploy, CPU/CUDA, ...).
TORCH_LIBRARY(mmdeploy, m) {
  m.def(
       "modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor "
       "mask, "
       "int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int "
       "dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor")
      .def(
          "coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, "
          "float score_threshold, int max_boxes) -> Tensor[]");
}
csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp
0 → 100644
View file @
546b4279
#include <assert.h>
#include <vector>
#include "torch/script.h"
namespace mmdeploy {
using at::Tensor;

// Shape-only CPU implementation of the `coreml_nms` schema.
//
// It validates input shapes and returns zero-filled tensors of the output
// shapes; the box contents and the iou/score thresholds are never inspected.
// Presumably the real NMS is substituted by the CoreML backend at conversion
// time — confirm against the exporter before relying on these values.
//
// Returns {ret_boxes (B, max_boxes, 4), ret_scores (B, max_boxes, C),
//          indices (B, max_boxes) int32, num_outputs (B,) int32}.
std::vector<Tensor> coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold,
                                   double score_threshold, int64_t max_boxes) {
  assert(boxes.dim() == 3);  // bboxes with shape (batch_size, num_bboxes, 4)
  assert(boxes.size(2) == 4);
  assert(boxes.size(0) == scores.size(0));  // check batch size
  assert(boxes.size(1) == scores.size(1));  // check num boxes

  auto batch_size = boxes.size(0);
  auto num_classes = scores.size(2);

  // NOTE: outputs use the default (float32, CPU) tensor options, not
  // boxes.options(); fixed shapes are all callers appear to need here.
  Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4});
  Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes});
  Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt);
  Tensor num_outputs = at::zeros({batch_size}, at::kInt);

  return std::vector<Tensor>({ret_boxes, ret_scores, indices, num_outputs});
}

// Bind the stub to the schema declared in TORCH_LIBRARY(mmdeploy, ...).
TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { m.impl("coreml_nms", coreml_nms_cpu); }

}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "modulated_deform_conv/modulated_deform_conv_cpu.h"
#include "torch/script.h"
namespace mmdeploy {

// Dispatch the scalar-typed 2D deformable im2col helper for one sample.
// Expands (data_im, data_offset, data_mask) into the column buffer
// `data_col` used by the GEMM in the forward pass. `batch_size` is accepted
// for signature symmetry with the CUDA variant but is not used by the
// underlying 2D helper.
void modulated_deformable_im2col_cpu(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int64_t batch_size, const int64_t channels, const int64_t height_im,
    const int64_t width_im, const int64_t height_col, const int64_t width_col,
    const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w,
    const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h,
    const int64_t dilation_w, int64_t deformable_group, at::Tensor data_col) {
  // num_axes should be smaller than block size
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] {
        const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t* data_col_ = data_col.data_ptr<scalar_t>();

        deformable_im2col_2d<scalar_t>(data_im_, data_offset_, data_mask_, height_im, width_im,
                                       kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
                                       dilation_h, dilation_w, channels, deformable_group,
                                       height_col, width_col, data_mask_ != nullptr, data_col_);
      }));
}
// CPU forward pass of modulated deformable convolution (DCNv2).
//
// For each sample: expand the input into a column buffer with
// modulated_deformable_im2col_cpu, then perform a grouped GEMM against the
// flattened weights and optionally add the bias.
//
// Shapes: input (B, C, H, W), weight (Cout, C/group, kh, kw),
// offset/mask as produced by the offset branch. Returns (B, Cout, Hout, Wout).
// Raises via AT_ERROR when the declared kernel size or the channel/group
// relationship does not match the weight tensor.
at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                             at::Tensor offset, at::Tensor mask, int64_t kernel_h,
                                             int64_t kernel_w, int64_t stride_h, int64_t stride_w,
                                             int64_t pad_h, int64_t pad_w, int64_t dilation_h,
                                             int64_t dilation_w, int64_t group,
                                             int64_t deformable_group, bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // Fixed argument list: report (expected h x w vs actual h x w). The
    // original passed kernel_h_, kernel_w, kernel_h_, kernel_w_, printing the
    // actual height twice and never the expected height.
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h,
             kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels,
             channels_kernel * group);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // resize output
  at::Tensor output =
      at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options());
  // resize temporary columns
  at::Tensor columns = at::zeros(
      {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out},
      input.options());

  // divide into group
  weight = weight.view(
      {group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width,
                                    height_out, width_out, kernel_h, kernel_w, pad_h, pad_w,
                                    stride_h, stride_w, dilation_h, dilation_w, deformable_group,
                                    columns);
    // Grouped GEMM: out[g] += weight[g] @ columns[g], written back in place.
    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }
  }

  // Fold the group dimension back into the channel dimension.
  output = output.view(
      {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
  return output;
}
// Bind the CPU kernel to the schema declared in TORCH_LIBRARY(mmdeploy, ...).
TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) {
  m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu);
}

}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "c10/cuda/CUDAStream.h"
#include "modulated_deform_conv/modulated_deform_conv_cuda.cuh"
#include "torch/script.h"
namespace mmdeploy {

// Launch the deformable im2col GPU kernel for one (possibly batched) sample
// on the current CUDA stream, expanding (data_im, data_offset, data_mask)
// into the column buffer `data_col` used by the GEMM in the forward pass.
void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int64_t batch_size, const int64_t channels, const int64_t height_im,
    const int64_t width_im, const int64_t height_col, const int64_t width_col,
    const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w,
    const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h,
    const int64_t dilation_w, const int64_t deformable_group, at::Tensor data_col) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  // One GPU thread per output column element across batch/channels.
  const int num_kernels = channels * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_cuda", ([&] {
        const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t* data_col_ = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<scalar_t>
            <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
               at::cuda::getCurrentCUDAStream()>>>(
                num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h,
                kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                channel_per_deformable_group, batch_size, channels, deformable_group, height_col,
                width_col, data_col_);
      }));
}
// CUDA forward pass of modulated deformable convolution (DCNv2). Mirrors the
// CPU implementation: per-sample im2col expansion on the current stream,
// then a grouped GEMM against the flattened weights, plus optional bias.
//
// Shapes: input (B, C, H, W), weight (Cout, C/group, kh, kw). Returns
// (B, Cout, Hout, Wout) on input's device. Raises via AT_ERROR when the
// declared kernel size or the channel/group relationship does not match the
// weight tensor.
at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                              at::Tensor offset, at::Tensor mask, int64_t kernel_h,
                                              int64_t kernel_w, int64_t stride_h, int64_t stride_w,
                                              int64_t pad_h, int64_t pad_w, int64_t dilation_h,
                                              int64_t dilation_w, int64_t group,
                                              int64_t deformable_group, bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // Fixed argument list: report (expected h x w vs actual h x w). The
    // original passed kernel_h_, kernel_w, kernel_h_, kernel_w_, printing the
    // actual height twice and never the expected height.
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h,
             kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels,
             channels_kernel * group);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // resize output
  at::Tensor output =
      at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options());
  // resize temporary columns
  at::Tensor columns = at::zeros(
      {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out},
      input.options());

  // divide into group
  weight = weight.view(
      {group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], 1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w, pad_h, pad_w,
                                     stride_h, stride_w, dilation_h, dilation_w, deformable_group,
                                     columns);
    // Grouped GEMM: out[g] += weight[g] @ columns[g], written back in place.
    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }
  }

  // Fold the group dimension back into the channel dimension.
  output = output.view(
      {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
  return output;
}
// Bind the CUDA kernel to the schema declared in TORCH_LIBRARY(mmdeploy, ...).
TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) {
  m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda);
}

}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt
0 → 100644
View file @
546b4279
# Copyright (c) OpenMMLab. All rights reserved.
project(ts_optimizer)

find_package(Torch REQUIRED)
# torch_python is not exported by find_package(Torch); locate it manually.
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")

if (NOT TARGET pybind11)
  add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11)
endif ()

file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp)
pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
# NOTE(review): target_link_directories expects directory paths, but is given
# the target alias mmdeploy::torchscript_ops here — confirm whether
# target_link_libraries was intended.
target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops)

# Emit the extension module directly into the python package tree.
set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
                      ${CMAKE_SOURCE_DIR}/mmdeploy/backend/torchscript)
csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <string>
#include "optimizer.h"
#include "passes/onnx/common_subgraph_elimination.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/fuse_select_assign.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"
namespace mmdeploy {
namespace torch_jit {

// Run the IR-specific optimization pipeline on `model` in place.
// Only the `ir` value selects a pipeline; `backend` is currently used in the
// error message only. Unknown combinations abort the process.
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
                          const std::string& backend = "torchscript") {
  if (ir == "torchscript") {
    model = optimize_for_torchscript(model);
  } else if (ir == "onnx") {
    model = optimize_for_onnx(model);
  } else {
    fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(),
            backend.c_str());
    exit(-1);
  }
}
// Python bindings: expose the optimizer entry point plus the individual ONNX
// graph passes under a `ts_optimizer.onnx` submodule.
PYBIND11_MODULE(ts_optimizer, m) {
  namespace py = pybind11;
  m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
        py::arg("ir") = std::string("torchscript"),
        py::arg("backend") = std::string("torchscript"));

  py::module_ onnx_module = m.def_submodule("onnx");
  onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
  onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
  onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
  onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"),
                  py::arg("params"));
  onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination,
                  py::arg("graph"), py::arg("params"));
}

}  // namespace torch_jit
}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp
0 → 100644
View file @
546b4279
// modify from:
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/ir/subgraph_matcher.cpp
#include "subgraph_matcher.h"
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/jit_log.h>
#include <regex>
#include <stack>
namespace
mmdeploy
{
namespace
torch_jit
{
using
torch
::
jit
::
AttributeKind
;
using
torch
::
jit
::
ClassType
;
using
torch
::
jit
::
Node
;
using
torch
::
jit
::
Symbol
;
using
torch
::
jit
::
Value
;
namespace
prim
{
using
namespace
::
c10
::
prim
;
}
namespace
attr
{
using
namespace
::
c10
::
attr
;
}
/**
 * \brief A class implementing an API for comparing subgraphs.
 */
class SubgraphMatcher::SubgraphMatcherImpl {
 public:
  explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute)
      : pattern_(pattern), match_attribute_(match_attribute) {}

  /**
   * \brief Compare matchGraph with the part of the graph denoted by a node \p
   * ANCHOR.
   *
   * The anchor node would be compared against the deepest node in the
   * match-graph. A node is considered matching if its number of inputs/outputs
   * is the same as in the corresponding matchGraph node, its type is the same,
   * and all nodes producing input-values also match.
   */
  bool matchesSubgraphFromAnchorNode(Node* anchor);

  /** \brief Return match map for nodes. */
  std::unordered_map<const Node*, Node*> nodes_map() const { return nodes_map_; }

  /** \brief Return match map for values. */
  std::unordered_map<const Value*, Value*> values_map() const { return values_map_; }

 private:
  bool matchValues(const Value* v1, Value* v2);
  bool matchNodes(const Node* n1, Node* n2);
  bool matchAttributes(const Node* n1, Node* n2);

  static bool isInput(const Value* v);
  static bool isOutput(const Value* v);

  // Pattern-node/value -> matched graph-node/value, filled during matching.
  std::unordered_map<const Node*, Node*> nodes_map_;
  std::unordered_map<const Value*, Value*> values_map_;

  const MatchAttribute match_attribute_;
  const Graph& pattern_;
  const Node* anchor_ = nullptr;
};
// A value is a graph input iff it is produced by the Param sentinel node.
bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) {
  return v->node()->kind() == prim::Param;
}
// A value is a graph output iff it appears in its owning graph's output list.
bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) {
  for (const Value* output : v->owningGraph()->outputs()) {
    if (v == output) {
      return true;
    }
  }
  return false;
}
/**
 * Compare two Values. V1 is from pattern, V2 is from the actual graph.
 *
 * The values are considered matching if:
 * 1) the nodes defining them match
 * 2) they have the same number of uses, except they are entry or exit nodes.
 */
bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) {
  // Check if we've already visited these values.
  if (values_map_.count(v1)) {
    if (values_map_.at(v1) != v2) {
      GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
                  " did not match because %", v1->debugName(),
                  " has already been matched with %", values_map_.at(v1)->debugName(), ".\n");
      return false;
    }
    return true;
  }

  // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
  // PARAM, we're comparing entering values - in these two cases the number of
  // uses don't need to be the same.
  if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
    GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
                " did not match because number of their uses is different.\n");
    return false;
  }

  // Add the values to the map before calling matchNodes to avoid infinite
  // recursion.
  GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
  values_map_[v1] = v2;
  return matchNodes(v1->node(), v2->node());
}
// Compare the attributes of pattern node n1 against graph node n2.
//
// Under FORCE_MATCH the attribute counts must agree exactly; otherwise only
// the attributes present on the pattern node are checked. String attributes
// on the pattern are interpreted as regular expressions matched against the
// graph node's string. Unsupported attribute kinds fail the match.
bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) {
  if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) {
    GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
    return false;
  }
  for (const Symbol& attr_name : n1->attributeNames()) {
    if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
      GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
                  "' did not match:\n", *n1, *n2);
      return false;
    }
    std::vector<int64_t> n1is, n2is;
    std::vector<double> n1fs, n2fs;
    switch (n1->kindOf(attr_name)) {
      case AttributeKind::s:
        // Pattern strings are regular expressions.
        if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
          GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
                      "' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1,
                      *n2);
          return false;
        }
        break;
      case AttributeKind::f:
        if (n1->f(attr_name) != n2->f(attr_name)) {
          GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
                      "' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1,
                      *n2);
          return false;
        }
        break;
      case AttributeKind::i:
        if (n1->i(attr_name) != n2->i(attr_name)) {
          GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
                      "' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1,
                      *n2);
          return false;
        }
        break;
      case AttributeKind::is:
        n1is = n1->is(attr_name);
        n2is = n2->is(attr_name);
        if (n1is.size() != n2is.size()) return false;
        for (size_t i = 0; i < n1is.size(); ++i) {
          if (n1is[i] != n2is[i]) return false;
        }
        break;
      case AttributeKind::fs:
        n1fs = n1->fs(attr_name);
        n2fs = n2->fs(attr_name);
        if (n1fs.size() != n2fs.size()) return false;
        for (size_t i = 0; i < n1fs.size(); ++i) {
          if (n1fs[i] != n2fs[i]) return false;
        }
        break;
      default: {
        // Other attributes types not supported yet
        GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
                    "' is not supported.\n", *n1, *n2);
        return false;
      }
    }
  }
  return true;
}
// True iff `str` ends with `suffix` (every string ends with the empty suffix).
static bool endsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}
/**
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
*
* The nodes are considered matching if:
* 1) N1 and N2 are of the same kind.
* 2) Number of inputs and outputs is the same.
* 3) All input and output values match.
*
* A special case is when N1 is PARAM - this is considered outside the pattern,
* so it matches everything.
*/
bool
SubgraphMatcher
::
SubgraphMatcherImpl
::
matchNodes
(
const
Node
*
n1
,
Node
*
n2
)
{
// Check if we've already visited these nodes.
if
(
nodes_map_
.
count
(
n1
))
{
return
nodes_map_
.
at
(
n1
)
==
n2
;
}
// Param node in pattern graph matches everything.
if
(
n1
->
kind
()
==
prim
::
Param
)
{
GRAPH_DEBUG
(
"Nodes matched:
\n
"
,
*
n1
,
*
n2
);
return
true
;
}
// We don't allow matches to span across blocks, so check if N2 is in the same
// block as the first (anchor) node.
if
(
n2
->
owningBlock
()
!=
anchor_
->
owningBlock
())
{
GRAPH_DEBUG
(
"Nodes did not match because it is in the different block:
\n
"
,
*
n1
,
*
n2
);
return
false
;
}
// Special handling for matching modules
if
(
n1
->
kind
()
==
Symbol
::
fromQualString
(
"match::module"
))
{
if
(
n2
->
kind
()
==
prim
::
GetAttr
)
{
if
(
!
n1
->
hasAttributeS
(
"name"
))
{
GRAPH_DEBUG
(
"Nodes did not match because special node match::module does not have 'name' "
"attribute:
\n
"
,
*
n1
,
*
n2
);
return
false
;
}
auto
t
=
n2
->
output
()
->
type
()
->
expect
<
ClassType
>
();
auto
real_typename
=
t
->
name
()
->
qualifiedName
();
auto
pattern_typename
=
n1
->
s
(
attr
::
name
);
if
(
!
endsWith
(
real_typename
,
pattern_typename
))
{
GRAPH_DEBUG
(
"Nodes did not match because expected module type is different:
\n
"
);
GRAPH_DEBUG
(
" actualtype: "
,
real_typename
,
"
\n
"
);
GRAPH_DEBUG
(
" expected type: "
,
pattern_typename
,
"
\n
"
);
GRAPH_DEBUG
(
"Nodes:"
,
*
n1
,
*
n2
);
return
false
;
}
}
}
else
{
if
(
n1
->
kind
()
!=
n2
->
kind
()
||
n1
->
outputs
().
size
()
!=
n2
->
outputs
().
size
()
||
n1
->
inputs
().
size
()
!=
n2
->
inputs
().
size
())
{
GRAPH_DEBUG
(
"Nodes did not match in their kind or number of inputs/outputs:
\n
"
,
*
n1
,
*
n2
);
return
false
;
}
if
(
match_attribute_
!=
NO_MATCH
)
{
if
(
!
matchAttributes
(
n1
,
n2
))
{
return
false
;
}
}
}
// Add nodes to the map before calling matchValues to avoid infinite
// recursion.
nodes_map_
[
n1
]
=
n2
;
for
(
const
auto
i
:
c10
::
irange
(
n1
->
outputs
().
size
()))
{
if
(
!
matchValues
(
n1
->
outputs
()[
i
],
n2
->
outputs
()[
i
]))
{
return
false
;
}
}
for
(
const
auto
i
:
c10
::
irange
(
n1
->
inputs
().
size
()))
{
if
(
!
matchValues
(
n1
->
inputs
()[
i
],
n2
->
inputs
()[
i
]))
{
return
false
;
}
}
GRAPH_DEBUG
(
"Nodes matched:
\n
"
,
*
n1
,
*
n2
);
return
true
;
}
/**
 * Recursively try to match pattern with the actual graph starting from the
 * exiting node in the pattern and anchor node in the actual graph.
 */
bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) {
  GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
  nodes_map_.clear();
  values_map_.clear();
  anchor_ = anchor;

  // NOTE(review): dereferencing nodes().end() relies on the torch JIT node
  // list exposing the Return sentinel there; its first input is the deepest
  // pattern node — confirm this stays valid across torch versions.
  const Node* bottom_node = *(pattern_.nodes().end());
  bottom_node = bottom_node->input(0)->node();

  if (!matchNodes(bottom_node, anchor)) {
    return false;
  }
  for (const Value* output : pattern_.outputs()) {
    AT_ASSERT(values_map_.count(output));
  }

  GRAPH_UPDATE("Pattern matched!\n");
  return true;
}
// Public facade: forwards to the pimpl.
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
    : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}

// Out-of-line so unique_ptr<SubgraphMatcherImpl> can destroy the complete type.
SubgraphMatcher::~SubgraphMatcher() = default;

bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
  return impl_->matchesSubgraphFromAnchorNode(anchor);
}

std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const {
  return impl_->nodes_map();
}

std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const {
  return impl_->values_map();
}

}  // namespace torch_jit
}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _SUBGRAPH_MATCHER_H_
#define _SUBGRAPH_MATCHER_H_

#include <torch/script.h>

#include <memory>

namespace mmdeploy {
namespace torch_jit {

using torch::jit::Graph;
using torch::jit::Node;
using torch::jit::Value;

// Attribute-comparison policy used while matching nodes.
// NO_MATCH skips attribute comparison entirely (the matcher only compares
// node kinds and topology).  FORCE_MATCH / TRY_MATCH both enable attribute
// comparison; the exact strictness difference is implemented in
// matchAttributes (see subgraph_matcher.cpp).
enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH };

// Finds occurrences of a pattern graph inside a target graph.  Matching is
// anchored: the caller proposes a candidate node and the matcher walks the
// pattern backwards from its exit node.  Implementation is hidden behind a
// pimpl so this header only needs torch's public types.
class SubgraphMatcher {
 public:
  explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);

  ~SubgraphMatcher();

  // Try to match the pattern with `anchor` bound to the pattern's bottom node.
  // Returns true on success; the maps below are then valid for this match.
  bool matchesSubgraphFromAnchorNode(Node* anchor);

  /** \brief Return match map for nodes. */
  std::unordered_map<const Node*, Node*> nodes_map() const;

  /** \brief Return match map for values. */
  std::unordered_map<const Value*, Value*> values_map() const;

 private:
  class SubgraphMatcherImpl;
  std::unique_ptr<SubgraphMatcherImpl> impl_;
};

}  // namespace torch_jit
}  // namespace mmdeploy
#endif
csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "optimizer.h"

#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>

// These passes first appeared in torch 1.9.  Checking only TORCH_VERSION_MINOR
// silently disables them on torch >= 2.0 (the minor version resets to 0), so
// the major version must be taken into account as well.
#if TORCH_VERSION_MAJOR > 1 || TORCH_VERSION_MINOR >= 9
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#endif
namespace
mmdeploy
{
using
torch
::
jit
::
Graph
;
// Passes that must run on every exported graph regardless of target:
// expand removal, operator canonicalization, and dead-code elimination.
// The graph is mutated in place; the same shared_ptr is returned so the call
// can be chained (`graph = required_passes(graph);`).
const std::shared_ptr<Graph>& required_passes(const std::shared_ptr<Graph>& graph) {
  RemoveExpands(graph);
  CanonicalizeOps(graph);
  EliminateDeadCode(graph);
  return graph;
}
// Freeze `model` and run the TorchScript inference-optimization pipeline on
// its forward graph.
//
// \param model scripted module to optimize (not modified; a frozen copy is
//              returned).
// \return the frozen, optimized module.
//
// Fix: the version gate previously tested only TORCH_VERSION_MINOR, which is
// wrong for torch >= 2.0 where the minor version resets to 0 — the 1.9+ fusion
// passes were silently skipped there.
Module optimize_for_torchscript(const Module& model) {
  auto frozen_model = freeze_module(model);
  auto graph = frozen_model.get_method("forward").graph();
  OptimizeFrozenGraph(graph, true);

#if TORCH_VERSION_MAJOR > 1 || TORCH_VERSION_MINOR >= 9
  FuseFrozenConvAddRelu(graph);
  ConvertFrozenOpsToMKLDNN(graph);
  FrozenLinearTranspose(graph);
#endif

  // Generic cleanup after the frozen-graph specific passes.
  graph = required_passes(graph);
  EliminateCommonSubexpression(graph);
  PeepholeOptimize(graph);
  ConstantPropagation(graph);
  ConstantPooling(graph);

  // TODO: add more custom passes
  return frozen_model;
}
// Freeze `model` (preserving the `training` attribute, which the ONNX
// exporter still needs) and run the frozen-graph optimizations that are safe
// before ONNX export.
//
// \param model scripted module to optimize (not modified; a frozen copy is
//              returned).
// \return the frozen, optimized module.
//
// Fix: the version gate previously tested only TORCH_VERSION_MINOR, which is
// wrong for torch >= 2.0 where the minor version resets to 0 — the 1.9+ fusion
// passes were silently skipped there.
Module optimize_for_onnx(const Module& model) {
  auto frozen_model = freeze_module(model, {"training"});
  auto graph = frozen_model.get_method("forward").graph();
  OptimizeFrozenGraph(graph, true);

#if TORCH_VERSION_MAJOR > 1 || TORCH_VERSION_MINOR >= 9
  FuseFrozenConvAddRelu(graph);
  ConvertFrozenOpsToMKLDNN(graph);
  FrozenLinearTranspose(graph);
#endif

  // TODO: add more custom passes
  return frozen_model;
}
// TODO: add optimizer for other backend/onnx
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
// Fix: this header had no include guard, unlike every sibling header in this
// directory; added one following the project's existing naming convention.
#ifndef _OPTIMIZER_H_
#define _OPTIMIZER_H_

#include <torch/script.h>

namespace mmdeploy {

using torch::jit::script::Module;

// Freeze `model` and run TorchScript inference-optimization passes on its
// forward graph; returns the frozen, optimized module.
Module optimize_for_torchscript(const Module& model);

// Freeze `model` (preserving `training`) and run the passes that prepare its
// forward graph for ONNX export; returns the frozen, optimized module.
Module optimize_for_onnx(const Module& model);

}  // namespace mmdeploy
#endif
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp
0 → 100644
View file @
546b4279
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/passes/common_subexpression_elimination.cpp
#include "common_subgraph_elimination.h"
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
namespace
mmdeploy
{
namespace
torch_jit
{
using
c10
::
Symbol
;
using
torch
::
jit
::
Block
;
using
torch
::
jit
::
EqualNode
;
using
torch
::
jit
::
HashNode
;
using
torch
::
jit
::
Node
;
using
torch
::
jit
::
Value
;
struct
EqualNodeWithParams
{
EqualNodeWithParams
(
std
::
unordered_map
<
std
::
string
,
Tensor
>&
params
)
:
params_
(
params
)
{}
bool
operator
()(
const
Node
*
lhs
,
const
Node
*
rhs
)
const
{
auto
lhs_inputs
=
lhs
->
inputs
();
auto
rhs_inputs
=
rhs
->
inputs
();
}
private:
std
::
unordered_map
<
std
::
string
,
Tensor
>&
params_
;
};
// Common-subexpression elimination over a torch::jit::Graph, adapted from
// torch's own CSE pass with two differences (see header): no AliasDb is
// consulted, and graph params with identical tensor values are also merged.
struct CommonSubexpressionEliminator {
  // debug name -> (param tensor, canonical Value*) for params already seen.
  using ParamMapType = std::unordered_map<std::string, std::pair<Tensor, Value*>>;

  CommonSubexpressionEliminator(std::shared_ptr<Graph> graph,
                                std::unordered_map<std::string, Tensor>& params)
      : graph_(std::move(graph)), params_(params) {}

  // Run CSE over the whole graph.  `parent_lookup_fn` resolves a node against
  // enclosing scopes; at top level it never finds anything.
  bool run(std::function<Node*(Node*)> parent_lookup_fn) {
    ParamMapType param_map;
    return run(graph_->block(), std::move(parent_lookup_fn), param_map);
  }

  // The function implements common subexpression elimination.
  // Since the nodes are visited in topological order, one pass is enough.
  // returns true if CSE made changes to a graph
  bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn, ParamMapType& param_map) {
    // Nodes seen so far in this block, keyed by structural hash/equality.
    std::unordered_set<Node*, HashNode, EqualNode> subexprs;
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto node = *it;

      // check if inputs come from params(graph input)
      auto node_inputs = node->inputs();
      for (auto input : node_inputs) {
        if (input->node()->kind() == Symbol::fromQualString("prim::Param")) {
          auto debug_name = input->debugName();

          // check if input in params_
          if (params_.find(debug_name) == params_.end()) continue;

          // check if input is already visited.
          if (param_map.find(debug_name) != param_map.end()) continue;

          // check if there is a param has same value with input
          auto val = params_[debug_name];
          bool update_map = true;
          for (auto kv : param_map) {
            auto param_val = kv.second.first;
            // Cheap device/dtype checks first; Tensor::equal compares values.
            if (val.device() != param_val.device()) continue;
            if (val.dtype() != param_val.dtype()) continue;
            if (!val.equal(param_val)) continue;
            // Duplicate param: redirect all uses to the canonical Value*.
            input->replaceAllUsesWith(kv.second.second);
            update_map = false;
            break;
          }

          // add input to param_map
          if (update_map) {
            param_map.emplace(debug_name,
                              std::make_pair<Tensor, Value*>(std::move(val), std::move(input)));
          }
        }
      }

      if (!node->blocks().empty()) {
        // Traverse sub-blocks.  The lambda lets inner blocks look up
        // subexpressions in this (outer) block before falling back to the
        // caller's lookup.  Nodes with sub-blocks are never CSE'd themselves.
        for (auto block : node->blocks()) {
          changed |= run(
              block,
              [&](Node* n) {
                auto existing = subexprs.find(n);
                if (existing != subexprs.end()) {
                  return *existing;
                }
                return parent_lookup_fn(n);
              },
              param_map);
        }
        continue;
      }

      // Check for CSE opportunities in the parent block.
      auto parent_lookup = parent_lookup_fn(node);
      // NOTE(review): g_out is computed but never used.
      auto g_out = node->owningGraph()->outputs();
      if (parent_lookup != nullptr) {
        changed = true;
        node->replaceAllUsesWith(parent_lookup);
        // destroyCurrent() safely removes the node while keeping `it` valid.
        it.destroyCurrent();
        continue;
      }

      // Check whether the same subexpression already exists.
      auto subit = subexprs.insert(node);
      if (!subit.second) {
        // Subexpression exists, replace the uses of node, and destroy it.
        auto existing = *subit.first;
        changed = true;
        node->replaceAllUsesWith(existing);
        // Destroy the node.
        it.destroyCurrent();
      }
    }
    return changed;
  }

 private:
  std::shared_ptr<Graph> graph_;
  // Borrowed reference to the exporter's param table; must outlive this object.
  std::unordered_map<std::string, Tensor>& params_;
};
void
CommonSubgraphElimination
(
std
::
shared_ptr
<
Graph
>&
graph
,
std
::
unordered_map
<
std
::
string
,
Tensor
>&
params
)
{
CommonSubexpressionEliminator
cse
(
graph
,
params
);
cse
.
run
([](
Node
*
)
{
return
nullptr
;
});
}
}
// namespace torch_jit
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _COMMON_SUBGRAPH_ELIMINATION_H_
#define _COMMON_SUBGRAPH_ELIMINATION_H_

#include <torch/script.h>

namespace mmdeploy {
namespace torch_jit {

using torch::Tensor;
using torch::jit::Graph;

// This pass is used eliminate the common subgraph.
// There are two main difference between the one in torch/csrc/jit/pass
// 1. AliasDb is not needed in ONNX model
// 2. params might also participated in the elimination
//
// \param graph  graph to rewrite in place.
// \param params exporter param table (debug name -> tensor); params with
//               identical tensor values may be merged into one graph input.
void CommonSubgraphElimination(std::shared_ptr<Graph>& graph,
                               std::unordered_map<std::string, Tensor>& params);

}  // namespace torch_jit
}  // namespace mmdeploy
#endif
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "flatten_cls_head.h"
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <vector>
#include "utils.h"
namespace
mmdeploy
{
namespace
torch_jit
{
using
c10
::
Symbol
;
using
torch
::
jit
::
IValue
;
using
torch
::
jit
::
Match
;
using
torch
::
jit
::
TensorType
;
using
torch
::
jit
::
TypeKind
;
using
torch
::
jit
::
Value
;
// Filter callback for the SubgraphRewriter below: returns true only when the
// matched shape->gather->unsqueeze->concat->reshape chain really is a
// "flatten classification head" that can be replaced by onnx::Flatten.
//
// \param match the candidate match produced by the rewriter.
// \param map   pattern-variable name -> pattern Value* (from the IR pattern).
static bool matchClsHead(const Match& match, const std::unordered_map<std::string, Value*>& map) {
  // TODO: check if value map in latest pytorch can ease the filter.
  // check cat -1
  {
    // check if the shape of second inputs is 1
    auto cat_v1 = match.values_map.at(map.at("cat1"));
    if (cat_v1->type()->kind() != TypeKind::TensorType) return false;
    auto cat_v1_type = cat_v1->type()->cast<TensorType>();
    auto cat_v1_size = cat_v1_type->sizes().concrete_sizes();
    // Static shape must be known to prove this is the [-1] tail of the shape.
    if (!cat_v1_size.has_value()) return false;
    IValue cat_v1_size_value(cat_v1_size.value());
    auto size_list = cat_v1_size_value.toIntList();
    if (size_list.size() != 1 || size_list[0] != 1) return false;
  }

  // check unsqueeze
  auto cat_v0 = match.values_map.at(map.at("cat0"));
  auto unsqueeze_node = cat_v0->node();
  {
    if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false;
    // Must unsqueeze exactly axis 0 (batch dim).
    auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes"));
    if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false;
  }

  // check gather
  auto gather_node = unsqueeze_node->input()->node();
  auto gather_inputs = gather_node->inputs();
  {
    if (!is_kind(gather_node, "onnx::Gather")) return false;
    auto gather_axis = gather_node->i(Symbol::attr("axis"));
    if (gather_axis != 0) return false;
  }

  auto x = match.values_map.at(map.at("x"));

  // check shape
  auto shape_node = gather_inputs[0]->node();
  {
    if (!is_kind(shape_node, "onnx::Shape")) return false;
    // The shape must be taken from the same tensor that feeds the GAP.
    if (shape_node->input() != x) return false;
  }

  // check constant
  auto const_node = gather_inputs[1]->node();
  {
    if (!is_kind(const_node, "onnx::Constant")) return false;
    auto ival = const_node->t(Symbol::attr("value"));
    // Scalar index 0: gather selects the batch dimension.
    if (ival.dim() != 0) return false;
    auto ival_dataptr = ival.data_ptr<int64_t>();
    if (ival_dataptr[0] != 0) return false;
  }

  // check if reshape is the output of the graph
  auto reshape_pattern = map.at("reshape");
  auto reshape_node = match.values_map.at(reshape_pattern);
  auto uses = reshape_node->uses();
  for (auto use : uses) {
    auto user = use.user;
    // Don't rewrite when the reshape is a graph output, since replacing it
    // would change the model's output value.
    if (is_kind(user, "prim::Return")) return false;
  }
  return true;
}
// from:
// x->shape->gather->unsqueeze->concat
// |                            |
// gap--------------------------reshape
//
// to:
// x->gap->flatten
//
// Rewrites the dynamic-shape classification-head pattern into a single
// onnx::Flatten following GlobalAveragePool; matchClsHead guards the rewrite.
void FlattenClsHead(std::shared_ptr<Graph>& graph) {
  // Pattern and replacement IR.  The IR text is parsed by the rewriter; do not
  // edit the operator names/attributes without updating matchClsHead.
  std::string pattern = R"IR(
graph(%x, %cat0, %cat1):
%gap = onnx::GlobalAveragePool(%x)
%cat = onnx::Concat[axis=0](%cat0, %cat1)
%reshape = onnx::Reshape(%gap, %cat)
return (%reshape)
)IR";
  std::string replacement = R"IR(
graph(%x, %cat0, %cat1):
%gap = onnx::GlobalAveragePool(%x)
%flatten = onnx::Flatten(%gap)
return (%flatten)
)IR";

  torch::jit::SubgraphRewriter subgraph_rewriter;
  subgraph_rewriter.RegisterRewritePattern(pattern, replacement);
  subgraph_rewriter.runOnGraph(graph, matchClsHead);

  // Drop the now-unused shape/gather/unsqueeze/concat chain.
  torch::jit::EliminateDeadCode(graph->block(), true,
                                torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
}
// namespace torch_jit
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FLATTEN_CLS_HEAD_H_
#define _FLATTEN_CLS_HEAD_H_

#include <torch/script.h>

namespace mmdeploy {
namespace torch_jit {

using torch::jit::Graph;

// Rewrite the shape->gather->unsqueeze->concat->reshape classification-head
// pattern after GlobalAveragePool into a single onnx::Flatten.  The graph is
// modified in place.
void FlattenClsHead(std::shared_ptr<Graph>& graph);

}  // namespace torch_jit
}  // namespace mmdeploy
#endif
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp
0 → 100644
View file @
546b4279
#include "fuse_select_assign.h"
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include "../../ir/subgraph_matcher.h"
#include "common_subgraph_elimination.h"
#include "torch/csrc/jit/ir/irparser.h"
namespace
mmdeploy
{
namespace
torch_jit
{
using
c10
::
Symbol
;
using
torch
::
jit
::
Block
;
using
torch
::
jit
::
IValue
;
using
torch
::
jit
::
Node
;
// Remove a redundant bool Cast that directly follows a comparison node.
// onnx Greater/Less already yield boolean outputs, so users of the Cast can
// consume the comparison result directly.  Returns true when the rewrite was
// applied, false when the producer is not a supported comparison.
bool RemoveBoolCast(Node* node) {
  Node* producer = node->input()->node();
  const auto producer_kind = producer->kind();
  const bool is_comparison =
      producer_kind == Symbol::onnx("Greater") || producer_kind == Symbol::onnx("Less");
  if (!is_comparison) return false;
  node->output()->replaceAllUsesWith(producer->output());
  return true;
}
// Verify a matched "y[x > thres] = z[x > thres]" subgraph and, if all
// constraints hold, replace it with a single onnx::Where node inserted before
// `node` (the ScatterND anchor).  Returns true when the fusion was applied.
//
// \param node    anchor node of the match (the onnx::ScatterND).
// \param params  exporter param table, used to read the slice start tensor.
// \param vmap    pattern-variable name -> pattern Value*.
// \param matcher the matcher whose last successful match is inspected.
bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& params,
                      std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
  auto values_map = matcher.values_map();

  auto cmp1 = values_map[vmap["cmp_1"]]->node();
  auto cmp2 = values_map[vmap["cmp_2"]]->node();
  if (cmp1 != cmp2) {
    // cmp_1 == cmp_2, cmp in (Great, Less)
    // Distinct nodes are acceptable only if they are provably the same
    // comparison: same kind, same threshold constant, same left operand.
    if (cmp1->kind() != cmp2->kind()) return false;
    if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less")))
      return false;

    // check threshold
    Node* cmps[] = {cmp1, cmp2};
    float thres = 0.0f;
    Node* x = nullptr;
    for (int i = 0; i < 2; ++i) {
      auto cmp = cmps[i];
      auto threshold = cmp->inputs()[1]->node();
      if (threshold->kind() != Symbol::onnx("Constant")) return false;
      // NOTE(review): assumes the threshold constant is a float tensor —
      // data_ptr<float> would misread other dtypes; verify against exporter.
      auto thres_val = threshold->t(Symbol::attr("value"));
      if (i == 0) {
        thres = thres_val.data_ptr<float>()[0];
        x = cmp->inputs()[0]->node();
      } else {
        float tmp_val = thres_val.data_ptr<float>()[0];
        if (fabs(thres - tmp_val) > 1e-10) {
          return false;
        }
        if (x != cmp->inputs()[0]->node()) {
          return false;
        }
      }
    }
  }
  {
    // check shape of reshape
    // The reshape target must be exactly [-1] (full flatten).
    Node* shape = values_map[vmap["reshape_1_shape"]]->node();
    auto shape_val = shape->t(Symbol::attr("value"));
    if (shape_val.dim() != 1) return false;
    if (shape_val.data_ptr<int64_t>()[0] != -1) return false;
  }
  {
    // check transpose
    // Both NonZero outputs must be transposed with perm = [1, 0].
    Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()};
    for (auto tran : trans) {
      auto tran_perm = tran->is(Symbol::attr("perm"));
      if (tran_perm.size() != 2) return false;
      if (tran_perm[0] != 1 || tran_perm[1] != 0) return false;
    }
  }
  {
    // check gather indice
    // Gather must read element 0 of the shape (the count of selected rows).
    Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
    auto inds_val = gather_inds->t(Symbol::attr("value"));
    if (inds_val.dim() != 0) return false;
    if (inds_val.data_ptr<int64_t>()[0] != 0) return false;
  }
  {
    // check slice start
    // The slice must start at 0 so the whole selection is kept.
    Node* slice = values_map[vmap["slice_2"]]->node();
    auto start_name = slice->inputs()[1]->debugName();
    auto start_val = params[start_name];
    if (start_val.dim() != 1) return false;
    if (start_val.data_ptr<int64_t>()[0] != 0) return false;
  }

  // create new node
  // All constraints hold: Where(cond, z, y) computes the same result as the
  // NonZero/GatherND/ScatterND chain.  The old nodes are left for DCE.
  auto graph = node->owningGraph();
  auto z = values_map[vmap["z"]];
  auto y = values_map[vmap["y"]];
  auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y});
  where_node->insertBefore(node);
  where_node->output()->copyMetadata(node->output());
  node->output()->replaceAllUsesWith(where_node->output());
  return true;
}
// Walk `block` (and nested blocks) looking for fusion opportunities:
// first strip redundant bool Casts (Cast with to=9, i.e. BOOL), then try the
// select-assign pattern anchored at each node.
void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& params,
                      std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
  // NOTE(review): graph is unused here.
  auto graph = block->owningGraph();
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    // Advance before processing so the iterator stays valid even if the
    // current node's uses are rewritten.
    ++it;
    for (auto block : node->blocks()) {
      FuseSelectAssign(block, params, vmap, matcher);
    }
    if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) {
      // to == 9 is ONNX's BOOL dtype code.
      RemoveBoolCast(node);
    } else if (matcher.matchesSubgraphFromAnchorNode(node)) {
      FuseSelectAssign(node, params, vmap, matcher);
    }
  }
}
// Top-level entry: fuse "y[x > thres] = z[x > thres]" style subgraphs in
// `graph` into a single onnx::Where.  The graph is modified in place.
//
// \param graph  graph to rewrite.
// \param params exporter param table (debug name -> tensor).
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
                      std::unordered_map<std::string, Tensor>& params) {
  // cse before search
  CommonSubgraphElimination(graph, params);

  // Pattern IR for the select-assign chain; variable names here must stay in
  // sync with the vmap lookups in the node-level FuseSelectAssign overload.
  std::string pattern_str = R"IR(
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2):
%nz_1 = onnx::NonZero(%cmp_1)
%trans_1 = onnx::Transpose(%nz_1)
%gather_1 = onnx::GatherND(%z, %trans_1)
%reshape_1_shape = onnx::Constant()
%reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
%expand_2 = onnx::Expand(%cmp_2, %shape_2)
%nz_2 = onnx::NonZero(%expand_2)
%trans_2 = onnx::Transpose(%nz_2)
%trans_shape_2 = onnx::Shape(%trans_2)
%gather_inds_2 = onnx::Constant()
%gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2)
%unsqueeze_2 = onnx::Unsqueeze(%gather_2)
%slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes)
%scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2)
return (%scatter_2)
)IR";

  Graph pattern;
  std::unordered_map<std::string, Value*> vmap;
  torch::jit::parseIR(pattern_str, &pattern, vmap);

  // Attributes are validated manually in the node-level overload, hence
  // NO_MATCH here.
  SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH);
  FuseSelectAssign(graph->block(), params, vmap, matcher);

  // Remove the now-dead NonZero/Gather/Scatter chain.
  torch::jit::EliminateDeadCode(graph->block(), true,
                                torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
}
// namespace torch_jit
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FUSE_SELECT_ASSIGN_H_
#define _FUSE_SELECT_ASSIGN_H_

#include <torch/script.h>

namespace mmdeploy {
namespace torch_jit {

using torch::Tensor;
using torch::jit::Graph;

// this pass is used to fuse y[x>thres] = z[x>thres]
// The NonZero/GatherND/ScatterND chain the exporter emits for such masked
// assignment is rewritten into a single onnx::Where.  `graph` is modified in
// place; `params` is the exporter param table (debug name -> tensor).
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
                      std::unordered_map<std::string, Tensor>& params);

}  // namespace torch_jit
}  // namespace mmdeploy
#endif
Prev
1
…
10
11
12
13
14
15
16
17
18
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment