"launch/dynamo-run/src/input/common.rs" did not exist on "c06b95ffdbb2f2eac710e3baa157e98889e19263"
Commit 546b4279 authored by limm's avatar limm
Browse files

add csrc and mmdeploy module

parent 502f4fb9
Pipeline #2810 canceled with stages
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <vector>
#include "common_cuda_helper.hpp"
#include "trt_plugin_helper.hpp"
using mmdeploy::TensorDesc;
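// ONNX ScatterND semantics: each row of `indices` (length indice_cols)
// addresses a slice of `output`, and the matching row of `update` is copied
// over that slice. One grid-stride loop iteration handles one index row.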
template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices, const T* update, T* output,
TensorDesc tensor_desc, TensorDesc indice_desc) {
const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
const int copy_stride = tensor_desc.stride[indice_cols - 1];
const int* stride = &(tensor_desc.stride[0]);
CUDA_1D_KERNEL_LOOP(index, n) {
int output_offset = 0;
const int* indices_current = indices + index * indice_cols;
for (int i = 0; i < indice_cols; ++i) {
output_offset += stride[i] * indices_current[i];
}
memcpy(output + output_offset, update + index * copy_stride, copy_stride * sizeof(T));
}
}
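// Worked example: data of shape (2, 3) has strides {3, 1}; an indices tensor
// of shape (1, 1) holding {1} gives indice_cols == 1 and copy_stride ==
// stride[0] == 3, so the single update row (3 elements) is written at
// output + 1 * 3, i.e. it overwrites output[1, :].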
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update,
const int* dims, int nbDims, const int* indices_dims,
int indice_nbDims, T* output, cudaStream_t stream) {
// fill and initialize the tensor descriptor
TensorDesc tensor_desc;
memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
tensor_desc.dim = nbDims;
tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
tensor_desc.stride[nbDims - 1] = 1;
for (int i = nbDims - 2; i >= 0; --i) {
tensor_desc.shape[i] = dims[i];
tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
}
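// e.g. dims = {2, 3, 4} yields shape {2, 3, 4} and row-major strides
// {12, 4, 1}, so stride[0] * shape[0] below equals the element count (24)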
const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];
TensorDesc indice_desc;
memset((void*)&indice_desc, 0, sizeof(TensorDesc));
indice_desc.dim = indice_nbDims;
indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
indice_desc.stride[indice_nbDims - 1] = 1;
for (int i = indice_nbDims - 2; i >= 0; --i) {
indice_desc.shape[i] = indices_dims[i];
indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
}
// output = np.copy(data)
cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
int num_update_indice = 1;
for (int i = 0; i < indice_nbDims - 1; ++i) {
num_update_indice *= indice_desc.shape[i];
}
// scatter
const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK);
onnx_scatternd_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
num_update_indice, indices, update, output, tensor_desc, indice_desc);
}
template void TRTONNXScatterNDKernelLauncher<float>(const float* data, const int* indices,
const float* update, const int* dims,
int nbDims, const int* indices_dims,
int indice_nbDims, float* output,
cudaStream_t stream);
template void TRTONNXScatterNDKernelLauncher<int>(const int* data, const int* indices,
const int* update, const int* dims, int nbDims,
const int* indices_dims, int indice_nbDims,
int* output, cudaStream_t stream);
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_KERNEL_HPP
#define TRT_SCATTERND_KERNEL_HPP
#include <cuda_runtime.h>
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update,
const int* dims, int nbDims, const int* indices_dims,
int indice_nbDims, T* output, cudaStream_t stream);
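// A minimal host-side sketch of a call (hypothetical device buffers d_data,
// d_indices, d_update, d_output), e.g. from a TensorRT plugin's enqueue():
//
//   const int dims[2] = {4, 4};          // data / output shape
//   const int indices_dims[2] = {2, 1};  // two index rows, one column each
//   TRTONNXScatterNDKernelLauncher<float>(d_data, d_indices, d_update, dims, 2,
//                                         indices_dims, 2, d_output, stream);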
#endif // TRT_SCATTERND_KERNEL_HPP
# Copyright (c) OpenMMLab. All rights reserved.
add_subdirectory(ops)
# Copyright (c) OpenMMLab. All rights reserved.
if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
project(mmdeploy_torchscript_ops CUDA CXX)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu)
else()
project(mmdeploy_torchscript_ops CXX)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp)
endif()
find_package(Torch REQUIRED)
if(MSVC)
# workaround to fix building torchscript ops on windows
set(_TORCH_TARGET torch_cuda_cu torch_cuda_cpp torch_cpu)
foreach(_target IN LISTS _TORCH_TARGET)
if(TARGET ${_target})
get_property(FIXED_TORCH_CPU_COMPILE_OPTIONS TARGET ${_target} PROPERTY INTERFACE_COMPILE_OPTIONS)
string(REPLACE ";" " " FIXED_TORCH_CPU_COMPILE_OPTIONS "${FIXED_TORCH_CPU_COMPILE_OPTIONS}")
set_property(TARGET ${_target} PROPERTY INTERFACE_COMPILE_OPTIONS -Xcompiler "${FIXED_TORCH_CPU_COMPILE_OPTIONS}")
else()
message(WARNING "Target ${_target} not found.")
endif()
endforeach()
endif()
add_library(${PROJECT_NAME}_obj OBJECT "${BACKEND_OPS_SRCS}")
set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1)
target_compile_definitions(${PROJECT_NAME}_obj
PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1)
target_include_directories(${PROJECT_NAME}_obj
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../common)
target_include_directories(${PROJECT_NAME}_obj
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common)
if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
target_include_directories(${PROJECT_NAME}_obj
PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include)
endif()
target_link_libraries(${PROJECT_NAME}_obj PRIVATE ${TORCH_LIBRARIES})
mmdeploy_export(${PROJECT_NAME}_obj)
# Build module library. It is used for inference with TorchScript.
mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME})
set(_TORCHJIT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_TORCHJIT_OPS_DIR})
// Copyright (c) OpenMMLab. All rights reserved.
#include "torch/script.h"
TORCH_LIBRARY(mmdeploy, m) {
m.def(
"modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor "
"mask, "
"int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int "
"dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor")
.def(
"coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, "
"float score_threshold, int max_boxes) -> Tensor[]");
}
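// Once this library is loaded (e.g. via torch.ops.load_library or by linking),
// the schemas above are exposed as torch.ops.mmdeploy.modulated_deform_conv
// and torch.ops.mmdeploy.coreml_nms in Python and TorchScript.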
#include <assert.h>
#include <vector>
#include "torch/script.h"
namespace mmdeploy {
using at::Tensor;
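// Shape-only stub: CoreML executes the real NMS at deployment time; this CPU
// implementation just returns zero-filled outputs of the right shapes so the
// op can be traced and scripted.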
std::vector<Tensor> coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold,
double score_threshold, int64_t max_boxes) {
assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4)
assert(boxes.size(2) == 4);
assert(boxes.size(0) == scores.size(0)); // check batch size
assert(boxes.size(1) == scores.size(1)); // check num boxes
auto batch_size = boxes.size(0);
auto num_boxes = boxes.size(1);
auto num_classes = scores.size(2);
Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4});
Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes});
Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt);
Tensor num_outputs = at::zeros({batch_size}, at::kInt);
return std::vector<Tensor>({ret_boxes, ret_scores, indices, num_outputs});
}
TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { m.impl("coreml_nms", coreml_nms_cpu); }
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#include "modulated_deform_conv/modulated_deform_conv_cpu.h"
#include "torch/script.h"
namespace mmdeploy {
void modulated_deformable_im2col_cpu(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int64_t batch_size, const int64_t channels, const int64_t height_im,
const int64_t width_im, const int64_t height_col, const int64_t width_col,
const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w,
const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h,
const int64_t dilation_w, int64_t deformable_group, at::Tensor data_col) {
// num_axes should be smaller than block size
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
deformable_im2col_2d<scalar_t>(data_im_, data_offset_, data_mask_, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channels, deformable_group,
height_col, width_col, data_mask_ != nullptr, data_col_);
}));
}
at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, at::Tensor weight, at::Tensor bias,
at::Tensor offset, at::Tensor mask, int64_t kernel_h,
int64_t kernel_w, int64_t stride_h, int64_t stride_w,
int64_t pad_h, int64_t pad_w, int64_t dilation_h,
int64_t dilation_w, int64_t group,
int64_t deformable_group, bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w,
kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels,
channels_kernel * group);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
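// e.g. height = 64, pad_h = 1, dilation_h = 1, kernel_h = 3, stride_h = 1:
// height_out = (64 + 2 - (2 + 1)) / 1 + 1 = 64 (a "same" convolution)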
// resize output
at::Tensor output =
at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options());
// resize temporary columns
at::Tensor columns = at::zeros(
{group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options());
// split the weight into groups
weight =
weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
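// im2col + GEMM: for each batch element, gather the (deformably offset and
// masked) sampling locations into `columns`, then one matrix multiply per
// group accumulates into the output via addmm_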
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w, deformable_group,
columns);
for (int g = 0; g < group; g++) {
output[b][g] =
output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]);
}
}
output = output.view(
{output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
return output;
}
TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) {
m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu);
}
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#include "c10/cuda/CUDAStream.h"
#include "modulated_deform_conv/modulated_deform_conv_cuda.cuh"
#include "torch/script.h"
namespace mmdeploy {
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int64_t batch_size, const int64_t channels, const int64_t height_im,
const int64_t width_im, const int64_t height_col, const int64_t width_col,
const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w,
const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h,
const int64_t dilation_w, const int64_t deformable_group, at::Tensor data_col) {
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
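// one thread per column-buffer element: num_kernels spans
// channels x batch x spatial positions of the output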
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_cuda", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
modulated_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h,
kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, batch_size, channels, deformable_group, height_col,
width_col, data_col_);
}));
}
at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor bias,
at::Tensor offset, at::Tensor mask, int64_t kernel_h,
int64_t kernel_w, int64_t stride_h, int64_t stride_w,
int64_t pad_h, int64_t pad_w, int64_t dilation_h,
int64_t dilation_w, int64_t group,
int64_t deformable_group, bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w,
kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels,
channels_kernel * group);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
// resize output
at::Tensor output =
at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options());
// resize temporary columns
at::Tensor columns = at::zeros(
{group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options());
// split the weight into groups
weight =
weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w, deformable_group,
columns);
for (int g = 0; g < group; g++) {
output[b][g] =
output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]);
}
}
output = output.view(
{output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
return output;
}
TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) {
m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda);
}
} // namespace mmdeploy
# Copyright (c) OpenMMLab. All rights reserved.
project(ts_optimizer)
find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
if (NOT TARGET pybind11)
add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11)
endif ()
file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp)
pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops)
set_target_properties(
${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
${CMAKE_SOURCE_DIR}/mmdeploy/backend/torchscript)
// Copyright (c) OpenMMLab. All rights reserved.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <string>
#include "optimizer.h"
#include "passes/onnx/common_subgraph_elimination.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/fuse_select_assign.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"
namespace mmdeploy {
namespace torch_jit {
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
const std::string& backend = "torchscript") {
if (ir == "torchscript") {
model = optimize_for_torchscript(model);
} else if (ir == "onnx") {
model = optimize_for_onnx(model);
} else {
fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(),
backend.c_str());
exit(-1);
}
}
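// Note: `backend` currently only affects the error message; dispatch is done
// on `ir` alone, and both branches ignore the backend name.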
PYBIND11_MODULE(ts_optimizer, m) {
namespace py = pybind11;
m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
py::arg("ir") = std::string("torchscript"),
py::arg("backend") = std::string("torchscript"));
py::module_ onnx_module = m.def_submodule("onnx");
onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"),
py::arg("params"));
onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination,
py::arg("graph"), py::arg("params"));
}
} // namespace torch_jit
} // namespace mmdeploy
// modified from:
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/ir/subgraph_matcher.cpp
#include "subgraph_matcher.h"
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/jit_log.h>
#include <regex>
#include <stack>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::AttributeKind;
using torch::jit::ClassType;
using torch::jit::Node;
using torch::jit::Symbol;
using torch::jit::Value;
namespace prim {
using namespace ::c10::prim;
}
namespace attr {
using namespace ::c10::attr;
}
/**
* \brief A class implementing an API for comparing subgraphs.
*/
class SubgraphMatcher::SubgraphMatcherImpl {
public:
explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute)
: pattern_(pattern), match_attribute_(match_attribute) {}
/**
* \brief Compare matchGraph with the part of the graph denoted by a node \p
* ANCHOR.
*
* The anchor node would be compared against the deepest node in the
* match-graph. A node is considered matching if its number of inputs/outputs
* is the same as in the corresponding matchGraph node, its type is the same,
* and all nodes producing input-values also match.
*/
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, Node*> nodes_map() const { return nodes_map_; }
/** \brief Return match map for values. */
std::unordered_map<const Value*, Value*> values_map() const { return values_map_; }
private:
bool matchValues(const Value* v1, Value* v2);
bool matchNodes(const Node* n1, Node* n2);
bool matchAttributes(const Node* n1, Node* n2);
static bool isInput(const Value* v);
static bool isOutput(const Value* v);
std::unordered_map<const Node*, Node*> nodes_map_;
std::unordered_map<const Value*, Value*> values_map_;
const MatchAttribute match_attribute_;
const Graph& pattern_;
const Node* anchor_ = nullptr;
};
bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) {
return v->node()->kind() == prim::Param;
}
bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) {
for (const Value* output : v->owningGraph()->outputs()) {
if (v == output) {
return true;
}
}
return false;
}
/**
* Compare two Values. V1 is from pattern, V2 is from the actual graph.
*
* The values are considered matching if:
* 1) the nodes defining them match
* 2) they have the same number of uses, unless they are entry or exit values.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) {
// Check if we've already visited these values.
if (values_map_.count(v1)) {
if (values_map_.at(v1) != v2) {
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
" did not match because %", v1->debugName(), " has already been matched with %",
values_map_.at(v1)->debugName(), ".\n");
return false;
}
return true;
}
// When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
// PARAM, we're comparing entering values - in these two cases the number of
// uses don't need to be the same.
if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
" did not match because number of their uses is different.\n");
return false;
}
// Add the values to the map before calling matchNodes to avoid infinite
// recursion.
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
values_map_[v1] = v2;
return matchNodes(v1->node(), v2->node());
}
bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) {
if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) {
GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
return false;
}
for (const Symbol& attr_name : n1->attributeNames()) {
if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
"' did not match:\n", *n1, *n2);
return false;
}
std::vector<int64_t> n1is, n2is;
std::vector<double> n1fs, n2fs;
switch (n1->kindOf(attr_name)) {
case AttributeKind::s:
if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::f:
if (n1->f(attr_name) != n2->f(attr_name)) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::i:
if (n1->i(attr_name) != n2->i(attr_name)) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::is:
n1is = n1->is(attr_name);
n2is = n2->is(attr_name);
if (n1is.size() != n2is.size()) return false;
for (size_t i = 0; i < n1is.size(); ++i) {
if (n1is[i] != n2is[i]) return false;
}
break;
case AttributeKind::fs:
n1fs = n1->fs(attr_name);
n2fs = n2->fs(attr_name);
if (n1fs.size() != n2fs.size()) return false;
for (size_t i = 0; i < n1fs.size(); ++i) {
if (n1fs[i] != n2fs[i]) return false;
}
break;
default: {
// Other attributes types not supported yet
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
"' is not supported.\n", *n1, *n2);
return false;
}
}
}
return true;
}
static bool endsWith(const std::string& str, const std::string& suffix) {
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
/**
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
*
* The nodes are considered matching if:
* 1) N1 and N2 are of the same kind.
* 2) Number of inputs and outputs is the same.
* 3) All input and output values match.
*
* A special case is when N1 is PARAM - this is considered outside the pattern,
* so it matches everything.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) {
// Check if we've already visited these nodes.
if (nodes_map_.count(n1)) {
return nodes_map_.at(n1) == n2;
}
// Param node in pattern graph matches everything.
if (n1->kind() == prim::Param) {
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
return true;
}
// We don't allow matches to span across blocks, so check if N2 is in the same
// block as the first (anchor) node.
if (n2->owningBlock() != anchor_->owningBlock()) {
GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2);
return false;
}
// Special handling for matching modules
if (n1->kind() == Symbol::fromQualString("match::module")) {
if (n2->kind() == prim::GetAttr) {
if (!n1->hasAttributeS("name")) {
GRAPH_DEBUG(
"Nodes did not match because special node match::module does not have 'name' "
"attribute:\n",
*n1, *n2);
return false;
}
auto t = n2->output()->type()->expect<ClassType>();
auto real_typename = t->name()->qualifiedName();
auto pattern_typename = n1->s(attr::name);
if (!endsWith(real_typename, pattern_typename)) {
GRAPH_DEBUG("Nodes did not match because expected module type is different:\n");
GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
GRAPH_DEBUG("Nodes:", *n1, *n2);
return false;
}
}
} else {
if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() ||
n1->inputs().size() != n2->inputs().size()) {
GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2);
return false;
}
if (match_attribute_ != NO_MATCH) {
if (!matchAttributes(n1, n2)) {
return false;
}
}
}
// Add nodes to the map before calling matchValues to avoid infinite
// recursion.
nodes_map_[n1] = n2;
for (const auto i : c10::irange(n1->outputs().size())) {
if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
return false;
}
}
for (const auto i : c10::irange(n1->inputs().size())) {
if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
return false;
}
}
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
return true;
}
/**
* Recursively try to match pattern with the actual graph starting from the
* exiting node in the pattern and anchor node in the actual graph.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) {
GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
nodes_map_.clear();
values_map_.clear();
anchor_ = anchor;
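// pattern_.nodes().end() dereferences to the graph's Return node; the
// producer of its first input is therefore the bottom (exiting) node of
// the pattern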
const Node* bottom_node = *(pattern_.nodes().end());
bottom_node = bottom_node->input(0)->node();
if (!matchNodes(bottom_node, anchor)) {
return false;
}
for (const Value* output : pattern_.outputs()) {
AT_ASSERT(values_map_.count(output));
}
GRAPH_UPDATE("Pattern matched!\n");
return true;
}
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
: impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}
SubgraphMatcher::~SubgraphMatcher() = default;
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
return impl_->matchesSubgraphFromAnchorNode(anchor);
}
std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const {
return impl_->nodes_map();
}
std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const {
return impl_->values_map();
}
} // namespace torch_jit
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _SUBGRAPH_MATCHER_H_
#define _SUBGRAPH_MATCHER_H_
#include <torch/script.h>
#include <memory>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
using torch::jit::Node;
using torch::jit::Value;
enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH };
class SubgraphMatcher {
public:
explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);
~SubgraphMatcher();
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, Node*> nodes_map() const;
/** \brief Return match map for values. */
std::unordered_map<const Value*, Value*> values_map() const;
private:
class SubgraphMatcherImpl;
std::unique_ptr<SubgraphMatcherImpl> impl_;
};
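// Typical use (sketch, mirroring passes/onnx/fuse_select_assign.cpp):
//   Graph pattern;
//   std::unordered_map<std::string, Value*> vmap;
//   torch::jit::parseIR(pattern_str, &pattern, vmap);
//   SubgraphMatcher matcher(pattern, NO_MATCH);
//   if (matcher.matchesSubgraphFromAnchorNode(node)) {
//     auto values_map = matcher.values_map();  // pattern Value* -> graph Value*
//   }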
} // namespace torch_jit
} // namespace mmdeploy
#endif
// Copyright (c) OpenMMLab. All rights reserved.
#include "optimizer.h"
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#if TORCH_VERSION_MINOR >= 9
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#endif
namespace mmdeploy {
using torch::jit::Graph;
const std::shared_ptr<Graph>& required_passes(const std::shared_ptr<Graph>& graph) {
RemoveExpands(graph);
CanonicalizeOps(graph);
EliminateDeadCode(graph);
return graph;
}
Module optimize_for_torchscript(const Module& model) {
auto frozen_model = freeze_module(model);
auto graph = frozen_model.get_method("forward").graph();
OptimizeFrozenGraph(graph, true);
#if TORCH_VERSION_MINOR >= 9
FuseFrozenConvAddRelu(graph);
ConvertFrozenOpsToMKLDNN(graph);
FrozenLinearTranspose(graph);
#endif
graph = required_passes(graph);
EliminateCommonSubexpression(graph);
PeepholeOptimize(graph);
ConstantPropagation(graph);
ConstantPooling(graph);
// TODO: add more custom passes
return frozen_model;
}
Module optimize_for_onnx(const Module& model) {
auto frozen_model = freeze_module(model, {"training"});
auto graph = frozen_model.get_method("forward").graph();
OptimizeFrozenGraph(graph, true);
#if TORCH_VERSION_MINOR >= 9
FuseFrozenConvAddRelu(graph);
ConvertFrozenOpsToMKLDNN(graph);
FrozenLinearTranspose(graph);
#endif
// TODO: add more custom passes
return frozen_model;
}
// TODO: add optimizer for other backend/onnx
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _OPTIMIZER_H_
#define _OPTIMIZER_H_
#include <torch/script.h>
namespace mmdeploy {
using torch::jit::script::Module;
Module optimize_for_torchscript(const Module &model);
Module optimize_for_onnx(const Module &model);
} // namespace mmdeploy
#endif // _OPTIMIZER_H_
// modified from:
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/passes/common_subexpression_elimination.cpp
#include "common_subgraph_elimination.h"
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Block;
using torch::jit::EqualNode;
using torch::jit::HashNode;
using torch::jit::Node;
using torch::jit::Value;
struct EqualNodeWithParams {
EqualNodeWithParams(std::unordered_map<std::string, Tensor>& params) : params_(params) {}
bool operator()(const Node* lhs, const Node* rhs) const {
// fall back to the default structural comparison; parameter-aware value
// matching is handled separately in CommonSubexpressionEliminator::run()
return EqualNode()(lhs, rhs);
}
private:
std::unordered_map<std::string, Tensor>& params_;
};
struct CommonSubexpressionEliminator {
using ParamMapType = std::unordered_map<std::string, std::pair<Tensor, Value*>>;
CommonSubexpressionEliminator(std::shared_ptr<Graph> graph,
std::unordered_map<std::string, Tensor>& params)
: graph_(std::move(graph)), params_(params) {}
bool run(std::function<Node*(Node*)> parent_lookup_fn) {
ParamMapType param_map;
return run(graph_->block(), std::move(parent_lookup_fn), param_map);
}
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
// returns true if CSE made changes to a graph
bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn, ParamMapType& param_map) {
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
bool changed = false;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto node = *it;
// check if inputs come from params (graph inputs)
auto node_inputs = node->inputs();
for (auto input : node_inputs) {
if (input->node()->kind() == Symbol::fromQualString("prim::Param")) {
auto debug_name = input->debugName();
// check if input in params_
if (params_.find(debug_name) == params_.end()) continue;
// check if input is already visited.
if (param_map.find(debug_name) != param_map.end()) continue;
// check if another param has the same value as this input
auto val = params_[debug_name];
bool update_map = true;
for (auto kv : param_map) {
auto param_val = kv.second.first;
if (val.device() != param_val.device()) continue;
if (val.dtype() != param_val.dtype()) continue;
if (!val.equal(param_val)) continue;
input->replaceAllUsesWith(kv.second.second);
update_map = false;
break;
}
// add input to param_map
if (update_map) {
param_map.emplace(debug_name,
std::make_pair<Tensor, Value*>(std::move(val), std::move(input)));
}
}
}
if (!node->blocks().empty()) {
// Traverse sub-blocks.
for (auto block : node->blocks()) {
changed |= run(
block,
[&](Node* n) {
auto existing = subexprs.find(n);
if (existing != subexprs.end()) {
return *existing;
}
return parent_lookup_fn(n);
},
param_map);
}
continue;
}
// Check for CSE opportunities in the parent block.
auto parent_lookup = parent_lookup_fn(node);
if (parent_lookup != nullptr) {
changed = true;
node->replaceAllUsesWith(parent_lookup);
it.destroyCurrent();
continue;
}
// Check whether the same subexpression already exists.
auto subit = subexprs.insert(node);
if (!subit.second) {
// Subexpression exists, replace the uses of node, and destroy it.
auto existing = *subit.first;
changed = true;
node->replaceAllUsesWith(existing);
// Destroy the node.
it.destroyCurrent();
}
}
return changed;
}
private:
std::shared_ptr<Graph> graph_;
std::unordered_map<std::string, Tensor>& params_;
};
void CommonSubgraphElimination(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params) {
CommonSubexpressionEliminator cse(graph, params);
cse.run([](Node*) { return nullptr; });
}
} // namespace torch_jit
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _COMMON_SUBGRAPH_ELIMINATION_H_
#define _COMMON_SUBGRAPH_ELIMINATION_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::Tensor;
using torch::jit::Graph;
// This pass is used to eliminate common subgraphs.
// There are two main differences from the one in torch/csrc/jit/passes:
// 1. AliasDb is not needed for ONNX models.
// 2. params might also participate in the elimination.
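// e.g. two graph inputs bound to tensors with identical device, dtype, and
// values are folded into one, redirecting all uses to the first occurrence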
void CommonSubgraphElimination(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params);
} // namespace torch_jit
} // namespace mmdeploy
#endif
// Copyright (c) OpenMMLab. All rights reserved.
#include "flatten_cls_head.h"
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <vector>
#include "utils.h"
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::IValue;
using torch::jit::Match;
using torch::jit::TensorType;
using torch::jit::TypeKind;
using torch::jit::Value;
static bool matchClsHead(const Match& match, const std::unordered_map<std::string, Value*>& map) {
// TODO: check if value map in latest pytorch can ease the filter.
// check cat -1
{
// check that the second concat input has shape [1]
auto cat_v1 = match.values_map.at(map.at("cat1"));
if (cat_v1->type()->kind() != TypeKind::TensorType) return false;
auto cat_v1_type = cat_v1->type()->cast<TensorType>();
auto cat_v1_size = cat_v1_type->sizes().concrete_sizes();
if (!cat_v1_size.has_value()) return false;
IValue cat_v1_size_value(cat_v1_size.value());
auto size_list = cat_v1_size_value.toIntList();
if (size_list.size() != 1 || size_list[0] != 1) return false;
}
// check unsqueeze
auto cat_v0 = match.values_map.at(map.at("cat0"));
auto unsqueeze_node = cat_v0->node();
{
if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false;
auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes"));
if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false;
}
// check gather
auto gather_node = unsqueeze_node->input()->node();
auto gather_inputs = gather_node->inputs();
{
if (!is_kind(gather_node, "onnx::Gather")) return false;
auto gather_axis = gather_node->i(Symbol::attr("axis"));
if (gather_axis != 0) return false;
}
auto x = match.values_map.at(map.at("x"));
// check shape
auto shape_node = gather_inputs[0]->node();
{
if (!is_kind(shape_node, "onnx::Shape")) return false;
if (shape_node->input() != x) return false;
}
// check constant
auto const_node = gather_inputs[1]->node();
{
if (!is_kind(const_node, "onnx::Constant")) return false;
auto ival = const_node->t(Symbol::attr("value"));
if (ival.dim() != 0) return false;
auto ival_dataptr = ival.data_ptr<int64_t>();
if (ival_dataptr[0] != 0) return false;
}
// check if reshape is the output of the graph
auto reshape_pattern = map.at("reshape");
auto reshape_node = match.values_map.at(reshape_pattern);
auto uses = reshape_node->uses();
for (auto use : uses) {
auto user = use.user;
if (is_kind(user, "prim::Return")) return false;
}
return true;
}
// from:
//   x->shape->gather->unsqueeze->concat
//   |                                |
//   gap------------------------->reshape
//
// to:
//   x->gap->flatten
void FlattenClsHead(std::shared_ptr<Graph>& graph) {
std::string pattern = R"IR(
    graph(%x, %cat0, %cat1):
      %gap = onnx::GlobalAveragePool(%x)
      %cat = onnx::Concat[axis=0](%cat0, %cat1)
      %reshape = onnx::Reshape(%gap, %cat)
      return (%reshape)
)IR";
std::string replacement = R"IR(
    graph(%x, %cat0, %cat1):
      %gap = onnx::GlobalAveragePool(%x)
      %flatten = onnx::Flatten(%gap)
      return (%flatten)
)IR";
torch::jit::SubgraphRewriter subgraph_rewriter;
subgraph_rewriter.RegisterRewritePattern(pattern, replacement);
subgraph_rewriter.runOnGraph(graph, matchClsHead);
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace torch_jit
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FLATTEN_CLS_HEAD_H_
#define _FLATTEN_CLS_HEAD_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
void FlattenClsHead(std::shared_ptr<Graph>& graph);
} // namespace torch_jit
} // namespace mmdeploy
#endif
#include "fuse_select_assign.h"
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include "../../ir/subgraph_matcher.h"
#include "common_subgraph_elimination.h"
#include "torch/csrc/jit/ir/irparser.h"
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;
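// ONNX Greater/Less already produce bool tensors, so a following
// Cast-to-bool (attribute to == 9, i.e. ONNX BOOL) is redundant and
// can be bypassed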
bool RemoveBoolCast(Node* node) {
auto bottom_node = node->input()->node();
if (bottom_node->kind() != Symbol::onnx("Greater") &&
bottom_node->kind() != Symbol::onnx("Less")) {
return false;
}
node->output()->replaceAllUsesWith(bottom_node->output());
return true;
}
bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& params,
std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
auto values_map = matcher.values_map();
auto cmp1 = values_map[vmap["cmp_1"]]->node();
auto cmp2 = values_map[vmap["cmp_2"]]->node();
if (cmp1 != cmp2) {
// distinct nodes: require the same comparison kind (Greater/Less) applied
// to the same input with the same threshold
if (cmp1->kind() != cmp2->kind()) return false;
if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less")))
return false;
// check threshold
Node* cmps[] = {cmp1, cmp2};
float thres = 0.0f;
Node* x = nullptr;
for (int i = 0; i < 2; ++i) {
auto cmp = cmps[i];
auto threshold = cmp->inputs()[1]->node();
if (threshold->kind() != Symbol::onnx("Constant")) return false;
auto thres_val = threshold->t(Symbol::attr("value"));
if (i == 0) {
thres = thres_val.data_ptr<float>()[0];
x = cmp->inputs()[0]->node();
} else {
float tmp_val = thres_val.data_ptr<float>()[0];
if (std::fabs(thres - tmp_val) > 1e-10) {
return false;
}
if (x != cmp->inputs()[0]->node()) {
return false;
}
}
}
}
{
// check shape of reshape
Node* shape = values_map[vmap["reshape_1_shape"]]->node();
auto shape_val = shape->t(Symbol::attr("value"));
if (shape_val.dim() != 1) return false;
if (shape_val.data_ptr<int64_t>()[0] != -1) return false;
}
{
// check transpose
Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()};
for (auto tran : trans) {
auto tran_perm = tran->is(Symbol::attr("perm"));
if (tran_perm.size() != 2) return false;
if (tran_perm[0] != 1 || tran_perm[1] != 0) return false;
}
}
{
// check gather indice
Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
auto inds_val = gather_inds->t(Symbol::attr("value"));
if (inds_val.dim() != 0) return false;
if (inds_val.data_ptr<int64_t>()[0] != 0) return false;
}
{
// check slice start
Node* slice = values_map[vmap["slice_2"]]->node();
auto start_name = slice->inputs()[1]->debugName();
auto start_val = params[start_name];
if (start_val.dim() != 1) return false;
if (start_val.data_ptr<int64_t>()[0] != 0) return false;
}
// create new node
auto graph = node->owningGraph();
auto z = values_map[vmap["z"]];
auto y = values_map[vmap["y"]];
auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y});
where_node->insertBefore(node);
where_node->output()->copyMetadata(node->output());
node->output()->replaceAllUsesWith(where_node->output());
return true;
}
void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& params,
std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
auto graph = block->owningGraph();
auto it = block->nodes().begin();
while (it != block->nodes().end()) {
auto node = *it;
++it;
for (auto block : node->blocks()) {
FuseSelectAssign(block, params, vmap, matcher);
}
if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) {
RemoveBoolCast(node);
} else if (matcher.matchesSubgraphFromAnchorNode(node)) {
FuseSelectAssign(node, params, vmap, matcher);
}
}
}
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params) {
// run common subgraph elimination before searching for the pattern
CommonSubgraphElimination(graph, params);
std::string pattern_str = R"IR(
    graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2):
      %nz_1 = onnx::NonZero(%cmp_1)
      %trans_1 = onnx::Transpose(%nz_1)
      %gather_1 = onnx::GatherND(%z, %trans_1)
      %reshape_1_shape = onnx::Constant()
      %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
      %expand_2 = onnx::Expand(%cmp_2, %shape_2)
      %nz_2 = onnx::NonZero(%expand_2)
      %trans_2 = onnx::Transpose(%nz_2)
      %trans_shape_2 = onnx::Shape(%trans_2)
      %gather_inds_2 = onnx::Constant()
      %gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2)
      %unsqueeze_2 = onnx::Unsqueeze(%gather_2)
      %slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes)
      %scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2)
      return (%scatter_2)
)IR";
Graph pattern;
std::unordered_map<std::string, Value*> vmap;
torch::jit::parseIR(pattern_str, &pattern, vmap);
SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH);
FuseSelectAssign(graph->block(), params, vmap, matcher);
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace torch_jit
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FUSE_SELECT_ASSIGN_H_
#define _FUSE_SELECT_ASSIGN_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::Tensor;
using torch::jit::Graph;
// This pass fuses y[x > thres] = z[x > thres] into a single ONNX Where node.
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params);
} // namespace torch_jit
} // namespace mmdeploy
#endif