Commit 546b4279 authored by limm's avatar limm
Browse files

add csrc and mmdeploy module

parent 502f4fb9
Pipeline #2810 canceled with stages
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_ARCHIVE_JSON_ARCHIVE_H_
#define MMDEPLOY_SRC_ARCHIVE_JSON_ARCHIVE_H_
#include "json.hpp"
#include "mmdeploy/core/archive.h"
#include "mmdeploy/core/value.h"
namespace mmdeploy {
namespace detail {
template <typename T>
nlohmann::json to_json_impl(T&& val);
inline nlohmann::json value_to_json(const Value& value) {
switch (value.type()) {
case ValueType::kNull:
return {};
case ValueType::kBool:
return value.get<bool>();
case ValueType::kInt:
return value.get<int64_t>();
case ValueType::kUInt:
return value.get<uint64_t>();
case ValueType::kFloat:
return value.get<double>();
case ValueType::kString:
return value.get<std::string>();
case ValueType::kArray: {
nlohmann::json json = nlohmann::json::value_t::array;
for (const auto& x : value) {
json.push_back(value_to_json(x));
}
return json;
}
case ValueType::kObject: {
nlohmann::json json = nlohmann::json::value_t::object;
for (auto it = value.begin(); it != value.end(); ++it) {
auto key = it.key();
json[key] = value_to_json(*it);
}
return json;
}
case ValueType::kAny:
return "<any>";
default:
return "<unknown>";
}
}
} // namespace detail
template <typename T, std::enable_if_t<!is_value_v<uncvref_t<T>>, int> = 0>
nlohmann::json to_json(T&& val) {
return detail::to_json_impl(std::forward<T>(val));
}
inline nlohmann::json to_json(const Value& value) { return detail::value_to_json(value); }
// save to JSON
class JsonOutputArchive : public OutputArchive<JsonOutputArchive> {
public:
explicit JsonOutputArchive(nlohmann::json& data) : data_(data) {}
void init(...) {}
template <typename T>
void named_value(const std::string& name, T&& val) {
data_[name] = to_json(std::forward<T>(val));
}
template <typename T>
void item(T&& val) {
data_.push_back(to_json(std::forward<T>(val)));
}
template <typename T, typename V = uncvref_t<T>,
std::enable_if_t<
std::disjunction_v<std::is_arithmetic<V>, std::is_same<V, const char*>,
std::is_same<V, std::string>, std::is_same<V, nlohmann::json>>,
int> = 0>
void native(T&& val) {
data_ = std::forward<T>(val);
}
private:
nlohmann::json& data_;
};
namespace detail {
template <typename T>
inline nlohmann::json to_json_impl(T&& val) {
nlohmann::json json;
JsonOutputArchive archive(json);
archive(std::forward<T>(val));
return json;
}
} // namespace detail
namespace detail {
inline Value json_to_value(const nlohmann::json& json) {
using value_t = nlohmann::json::value_t;
switch (json.type()) {
case value_t::null:
return {};
case value_t::boolean:
return json.get<bool>();
case value_t::number_integer:
return json.get<int64_t>();
case value_t::number_unsigned:
return json.get<uint64_t>();
case value_t::number_float:
return json.get<double>();
case value_t::string:
return json.get<std::string>();
case value_t::array: {
Value value = ValueType::kArray;
for (const auto& x : json) {
value.push_back(json_to_value(x));
}
return value;
}
case value_t::object: {
Value value = ValueType::kObject;
for (const auto& proxy : json.items()) {
value[proxy.key()] = json_to_value(proxy.value());
}
return value;
}
default:
MMDEPLOY_ERROR("unsupported json type: {}", json.type_name());
return {};
}
}
template <typename T>
void from_json_impl(const nlohmann::json& json, T&& val);
} // namespace detail
template <typename T, std::enable_if_t<!std::is_same_v<Value, uncvref_t<T>>, int> = 0>
void from_json(const nlohmann::json& json, T&& val) {
detail::from_json_impl(json, std::forward<T>(val));
}
inline void from_json(const nlohmann::json& json, Value& val) { val = detail::json_to_value(json); }
template <typename T>
T from_json(const nlohmann::json& json);
// load from JSON
class JsonInputArchive : public InputArchive<JsonInputArchive> {
public:
explicit JsonInputArchive(const nlohmann::json& data) : data_(data) {}
template <typename SizeType>
void init(SizeType& size) {
size = static_cast<SizeType>(data_.size());
iter_ = data_.begin();
}
template <typename T>
void named_value(std::string& name, T& val) {
name = iter_.key();
from_json(*iter_++, std::forward<T>(val));
}
template <typename T>
void named_value(const std::string& name, T&& val) {
from_json(data_[name], std::forward<T>(val));
}
template <typename T>
void item(T&& val) {
from_json(*iter_++, std::forward<T>(val));
}
template <typename T>
void native(T&& val) {
data_.get_to(val);
}
private:
const nlohmann::json& data_;
nlohmann::json::const_iterator iter_;
};
namespace detail {
template <typename T>
inline void from_json_impl(const nlohmann::json& json, T&& val) {
JsonInputArchive archive(json);
archive(std::forward<T>(val));
}
} // namespace detail
template <typename T>
inline T from_json(const nlohmann::json& json) {
T val{};
from_json(json, val);
return val;
}
void from_json(const nlohmann::json& json, Value& val);
} // namespace mmdeploy
#endif // MMDEPLOY_SRC_ARCHIVE_JSON_ARCHIVE_H_
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_ARCHIVE_VALUE_ARCHIVE_H_
#define MMDEPLOY_SRC_ARCHIVE_VALUE_ARCHIVE_H_
#include "mmdeploy/core/archive.h"
#include "mmdeploy/core/value.h"
namespace mmdeploy {
template <typename T>
Value to_value(T&& val);
// save to Value
class ValueOutputArchive : public OutputArchive<ValueOutputArchive> {
public:
explicit ValueOutputArchive(Value& data) : data_(data) {}
template <typename T>
void init(array_tag<T>) {
data_ = ValueType::kArray;
}
template <typename T>
void init(object_tag<T>) {
data_ = ValueType::kObject;
}
template <typename T>
void named_value(const std::string& name, T&& val) {
data_[name] = to_value(std::forward<T>(val));
}
template <typename T>
void item(T&& val) {
data_.push_back(to_value(std::forward<T>(val)));
}
template <typename T, std::enable_if_t<std::is_constructible_v<Value, T>, int> = 0>
void native(T&& val) {
data_ = std::forward<T>(val);
};
private:
Value& data_;
};
template <typename T>
inline Value to_value(T&& val) {
Value value;
ValueOutputArchive archive(value);
archive(std::forward<T>(val));
return value;
}
// fast path
inline Value to_value(const Value& v) { return v; }
inline Value to_value(Value&& v) { return std::move(v); }
template <typename T>
void from_value(const Value& value, T&& x);
template <typename T>
T from_value(const Value& value);
// load from Value
class ValueInputArchive : public InputArchive<ValueInputArchive> {
public:
explicit ValueInputArchive(const Value& data) : data_(data) {}
template <typename SizeType>
void init(SizeType& size) {
size = static_cast<SizeType>(data_.size());
iter_ = data_.begin();
}
template <typename T>
void named_value(std::string& name, T& val) {
name = iter_.key();
from_value(*iter_, std::forward<T>(val));
++iter_;
}
template <typename T>
void named_value(const std::string& name, T&& val) {
from_value(data_[name], std::forward<T>(val));
}
template <typename T>
void item(T&& val) {
from_value(*iter_, std::forward<T>(val));
++iter_;
}
template <typename T>
void native(T&& val) {
data_.get_to(val);
}
template <typename T>
void value(T&& value) {}
private:
const Value& data_;
Value::const_iterator iter_;
};
template <typename T>
void from_value(const Value& value, T&& x) {
ValueInputArchive archive(value);
archive(std::forward<T>(x));
}
// Required to avoid Value::Pointer being unwrapped by Value::get_to()
inline void from_value(const Value& value, Value& x) { x = value; }
template <typename T>
inline T from_value(const Value& value) {
T x{};
from_value(value, x);
return x;
}
namespace detail {
inline void load(ValueInputArchive& archive, Value& v) { archive.native(v); }
template <class T, std::enable_if_t<std::is_same<std::decay_t<T>, Value>::value, bool> = true>
inline void save(ValueOutputArchive& archive, T&& v) {
archive.native(std::forward<T>(v));
}
} // namespace detail
} // namespace mmdeploy
#endif // MMDEPLOY_SRC_ARCHIVE_VALUE_ARCHIVE_H_
if (NOT MSVC)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
endif ()
# build ONNXRUNTIME ops
if ("ort" IN_LIST MMDEPLOY_TARGET_BACKENDS)
if (NOT DEFINED ONNXRUNTIME_DIR)
set(ONNXRUNTIME_DIR $ENV{ONNXRUNTIME_DIR})
endif ()
if (NOT ONNXRUNTIME_DIR)
message(FATAL_ERROR " ONNXRUNTIME_DIR is not found.")
else ()
message(STATUS "Build ONNXRUNTIME custom ops.")
add_subdirectory(onnxruntime)
endif ()
endif ()
# build TensorRT ops
if ("trt" IN_LIST MMDEPLOY_TARGET_BACKENDS)
if (NOT DEFINED TENSORRT_DIR)
set(TENSORRT_DIR $ENV{TENSORRT_DIR})
endif ()
message(STATUS "Build TensorRT custom ops.")
add_subdirectory(tensorrt)
endif ()
# build ncnn ops
if ("ncnn" IN_LIST MMDEPLOY_TARGET_BACKENDS)
message(STATUS "Build ncnn custom ops")
add_subdirectory(ncnn)
endif ()
# build TorchScript ops
if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS
OR "coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS)
message(STATUS "Build torchscript custom ops")
add_subdirectory(torchscript)
endif ()
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef COMMON_CUDA_HELPER
#define COMMON_CUDA_HELPER
#include <cublas_v2.h>
#include <cuda.h>
#include <algorithm>
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define THREADS_PER_BLOCK 512
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
inline int GET_BLOCKS(const int N) {
int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK);
int max_block_num = 4096;
return std::min(optimal_block_num, max_block_num);
}
#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(0); \
} \
}
/**
* Returns a view of the original tensor with its dimensions permuted.
*
* @param[out] dst pointer to the destination tensor
* @param[in] src pointer to the source tensor
* @param[in] src_size shape of the src tensor
* @param[in] permute The desired ordering of dimensions
* @param[in] src_dim dim of src tensor
* @param[in] stream cuda stream handle
*/
template <class scalar_t>
void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim,
cudaStream_t stream = 0);
template <typename scalar_t>
cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha,
const scalar_t* A, int lda, const scalar_t* B, int ldb,
const scalar_t* beta, scalar_t* C, int ldc);
template <typename scalar_t>
__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width,
scalar_t y, scalar_t x) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (scalar_t)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (scalar_t)x_low;
} else {
x_high = x_low + 1;
}
scalar_t ly = y - y_low;
scalar_t lx = x - x_low;
scalar_t hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
scalar_t v1 = input[y_low * width + x_low];
scalar_t v2 = input[y_low * width + x_high];
scalar_t v3 = input[y_high * width + x_low];
scalar_t v4 = input[y_high * width + x_high];
scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
#endif // COMMON_CUDA_HELPER
#include <cmath>
#include <cstdint>
template <typename T>
T bilinear_interpolate_2d(const T *src, const int64_t src_h, const int64_t src_w, const T h,
const T w) {
if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) {
return 0;
}
int64_t h_low = floor(h);
int64_t w_low = floor(w);
int64_t h_high = h_low + 1;
int64_t w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high];
T v3 = 0;
if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low];
T v4 = 0;
if (h_high <= src_h - 1 && w_high <= src_w - 1) v4 = src[h_high * src_w + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
// output: (channels * kernel_h * kernel_w, dst_h * dst_w)
template <typename T>
void deformable_im2col_2d(const T *input, const T *offset, const T *mask, const int64_t src_h,
const int64_t src_w, const int64_t kernel_h, const int64_t kernel_w,
const int64_t pad_h, const int64_t pad_w, const int64_t stride_h,
const int64_t stride_w, const int64_t dilation_h,
const int64_t dilation_w, const int64_t channels,
const int64_t offset_groups, const int64_t dst_h, const int64_t dst_w,
const bool use_mask, T *columns) {
const int64_t workload = channels * dst_h * dst_w;
for (int64_t index = 0; index != workload; ++index) {
const int64_t ow = index % dst_w;
const int64_t oh = (index / dst_w) % dst_h;
const int64_t ic = index / (dst_w * dst_h);
const int64_t oc = ic * kernel_h * kernel_w;
int64_t c_per_offset_grp = channels / offset_groups;
const int64_t grp_idx = ic / c_per_offset_grp;
auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow);
auto input_ptr = input + ic * (src_h * src_w);
auto offset_ptr = offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w;
}
for (int64_t kh = 0; kh < kernel_h; ++kh) {
for (int64_t kw = 0; kw < kernel_w; ++kw) {
const int64_t mask_idx = kh * kernel_w + kw;
const int64_t offset_idx = 2 * mask_idx;
T mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow];
}
const T offset_h = offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow];
const T offset_w = offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow];
const T ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h;
const T iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w;
*columns_ptr = mask_value * bilinear_interpolate_2d<T>(input_ptr, src_h, src_w, ih, iw);
columns_ptr += dst_h * dst_w;
}
}
}
}
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#ifndef TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
#define TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
#include <cuda_fp16.h>
#include <float.h>
#include "common_cuda_helper.cuh"
template <typename T>
__device__ float mdcn_im2col_bilinear(const T *input, const int data_width, const int height,
const int width, float h, float w) {
int h_low = floorf(h);
int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1) v2 = input[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0) v3 = input[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) v4 = input[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return float(val);
}
template <>
__device__ float mdcn_im2col_bilinear<__half>(const __half *input, const int data_width,
const int height, const int width, float h, float w) {
int h_low = floorf(h);
int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
float lh = h - h_low;
float lw = w - w_low;
float hh = 1 - lh, hw = 1 - lw;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = __half2float(input[h_low * data_width + w_low]);
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) v2 = __half2float(input[h_low * data_width + w_high]);
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) v3 = __half2float(input[h_high * data_width + w_low]);
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = __half2float(input[h_high * data_width + w_high]);
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void modulated_deformable_im2col_gpu_kernel(
const int n, const T *data_im, const T *data_offset, const T *data_mask, const int height,
const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size, const int num_channels,
const int deformable_group, const int height_col, const int width_col, T *data_col) {
CUDA_1D_KERNEL_LOOP(index, n) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const T *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
float val = 0.0f;
const float h_im = h_in + i * dilation_h + (float)offset_h;
const float w_im = w_in + j * dilation_w + (float)offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
val = mdcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
*data_col_ptr = (T)(val * (float)mask);
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
#endif // TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
# Copyright (c) OpenMMLab. All rights reserved.
# ncnn
find_package(ncnn)
if (ncnn_FOUND)
message(STATUS "ncnn library found!")
else ()
message(FATAL_ERROR "Could not locate ncnn")
endif ()
if (NOT ANDROID AND NOT IOS AND NOT CMAKE_CROSSCOMPILING)
add_subdirectory(ops)
add_subdirectory(onnx2ncnn)
add_subdirectory(pyncnn_ext)
else ()
# In case of embedded platform, like android, or ios, we only build custom ncnn
# ops, and leave the executable converter(onnx2ncnn, pyncnn_ext) built under
# the host platforms
add_subdirectory(ops)
endif ()
# Copyright (c) OpenMMLab. All rights reserved.
project(onnx2ncnn)
find_package(Protobuf)
if (PROTOBUF_FOUND)
if (${Protobuf_PROTOC_EXECUTABLE} STREQUAL "")
message(FATAL_ERROR "protoc not found, try `-DProtobuf_PROTOC_EXECUTABLE=/path/to/protoc`")
endif ()
protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto)
add_executable(mmdeploy_onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
target_include_directories(mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR}
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES})
if (MSVC)
target_compile_options(mmdeploy_onnx2ncnn PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/Za>)
endif()
set(_NCNN_CONVERTER_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/backend/ncnn)
install(TARGETS mmdeploy_onnx2ncnn DESTINATION ${_NCNN_CONVERTER_DIR})
else ()
message(
FATAL_ERROR "Protobuf not found, onnx model convert tool won't be built")
endif ()
// Copyright (c) OpenMMLab. All rights reserved.
#include "fuse_pass.h"
void fuse_identity(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
// fuse
// identity --> op
// to
// noop_reducencnn --> op
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
for (int j = 0; j < node->input_size(); ++j) {
std::string output_name = node->input(j);
onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name);
if (last_node && last_node->op_type() == "Identity") {
node->set_input(j, last_node->input(0));
node_reference[last_node->output(0)] -= 1;
node_reference[last_node->input(0)] += 1;
if (node_reference[last_node->output(0)] == 0) {
last_node->set_op_type("noop_reducedncnn");
node_reference[last_node->input(0)] -= 1;
reduced_node_count += 1;
}
}
}
}
}
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* gather = mutable_graph->mutable_node(i);
if (gather->op_type() != "Gather") {
continue;
}
if (weights.find(std::string(gather->input(1))) == weights.end()) {
continue;
}
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
if (indices.size() != 1) {
continue;
}
{
// reconstruct node connections
node_reference[gather->input(1)] -= 1;
std::string origin_inp = gather->input(0);
gather->clear_input();
gather->add_input(origin_inp);
}
{
// update axis, starts and ends
int axis = get_node_attr_i(*gather, "axis", 1) - 1;
gather->set_op_type("Crop");
gather->clear_attribute();
int indice = indices[0];
set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
}
}
}
void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// weight <= Reshape(weight)
if (node->op_type() == "Reshape") {
// check weight
if (weights.find(node->input(0)) == weights.end()) continue;
weights[node->output(0)] = weights[node->input(0)];
// set weight shape directly
std::vector<int> shape;
if (node->input_size() == 1) {
shape = get_node_attr_ai(*node, "shape");
} else if (node->input_size() == 2) {
// opset 5
shape = get_node_attr_from_input_ai(weights[node->input(1)]);
}
weights[node->output(0)].clear_dims();
for (int j = 0; j < shape.size(); j++) {
weights[node->output(0)].add_dims(shape[j]);
}
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// weight <= Transpose(weight)
if (node->op_type() == "Transpose") {
// check weight
if (weights.find(node->input(0)) == weights.end()) continue;
if (weights[node->input(0)].dims_size() != 2) continue;
// perm = (1, 0)
std::vector<int> perm = get_node_attr_ai(*node, "perm");
if (perm.size() != 2) continue;
if (perm[0] != 1 || perm[1] != 0) continue;
weights[node->output(0)] = weights[node->input(0)];
// permute weight
{
onnx::TensorProto& B = weights[node->output(0)];
const int h = B.dims(0);
const int w = B.dims(1);
std::vector<float> permuted_data;
permuted_data.reserve((size_t)h * w);
const float* bptr =
B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
for (int j = 0; j < w; j++) {
for (int k = 0; k < h; k++) {
float vb = bptr[k * w + j];
permuted_data.push_back(vb);
}
}
B.set_dims(0, w);
B.set_dims(1, h);
if (B.has_raw_data()) {
B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
} else {
for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]);
}
}
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// ShuffleChannel <= Reshape - Transpose - Reshape
// ShuffleChannel <= Reshape - Transpose - Constant - Reshape
if (node->op_type() == "Reshape") {
if (node_reference[node->output(0)] != 1) continue;
std::vector<int> shape;
if (node->input_size() == 1) {
shape = get_node_attr_ai(*node, "shape");
} else {
// skip weight reshape
if (weights.find(node->input(1)) == weights.end()) continue;
shape = get_node_attr_from_input_ai(weights[node->input(1)]);
}
// 1 groups channels_per_group, height, width
// reverse style = channels_per_group, groups, height * width
if (shape.size() != 5 && shape.size() != 3) continue;
if (shape.size() == 5 && shape[0] != 1) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
}
if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;
if (node_reference[node2->output(0)] != 1) continue;
// 0 2 1 3 4
// reverse style = 1 0 2
std::vector<int> perm = get_node_attr_ai(*node2, "perm");
if (perm.size() != 5 && perm.size() != 3) continue;
if (perm.size() == 5 &&
(perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
continue;
if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue;
std::vector<int> shape3;
if (node3->input_size() == 1) {
shape3 = get_node_attr_ai(*node3, "shape");
} else {
// skip weight reshape
if (weights.find(node3->input(1)) == weights.end()) continue;
shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
}
// 1, -1, height, width
// reverse style = group, -1, channels_per_group, height, width
if (shape3.size() != 4 && shape3.size() != 5) continue;
if (shape3.size() == 4 &&
(shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
continue;
if (shape3.size() == 5 &&
(shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node3->set_op_type("ShuffleChannel");
node3->set_input(0, node->input(0));
onnx::AttributeProto* attr_group = node3->add_attribute();
attr_group->set_name("group");
attr_group->set_i(shape[1]);
onnx::AttributeProto* attr_reverse = node3->add_attribute();
attr_reverse->set_name("reverse");
attr_reverse->set_i(shape.size() == 3);
reduced_node_count += 2;
i += 2;
}
}
}
void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
if (node->op_type() == "ShuffleChannel") {
// reverse = 1
int reverse = get_node_attr_i(*node, "reverse");
if (reverse != 1) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue;
// axis = 0
int gather2_axis = get_node_attr_i(*node2, "axis");
if (gather2_axis != 0) continue;
// indices = 0
if (weights.find(node2->input(1)) == weights.end()) continue;
std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue;
// axis = 0
int gather3_axis = get_node_attr_i(*node3, "axis");
if (gather3_axis != 0) continue;
// indices = 1
if (weights.find(node3->input(1)) == weights.end()) continue;
std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue;
// reduce
node2->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 2;
node_reference[node2->input(1)] -= 1;
node_reference[node3->input(1)] -= 1;
node3->set_op_type("Split");
node3->clear_input();
node3->add_input(node->output(0));
node3->add_output(node3->output(0));
node3->set_output(0, node2->output(0));
node3->clear_attribute();
onnx::AttributeProto* attr_axis = node3->add_attribute();
attr_axis->set_name("axis");
attr_axis->set_i(1);
reduced_node_count += 1;
i += 1;
}
}
}
/**
* @brief fuse subgraph
*
* conv - - - - - - - - - - - -> reshape
* \ /
* shape - slice - concat
*
* to
*
* conv --> reshape
*
* @param mutable_graph
* @param weights
* @param node_reference
* @param blob_names
* @param reduced_node_count
*/
void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
std::map<std::string, std::vector<int>> shape_context;
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* conv = mutable_graph->mutable_node(i);
if (conv->op_type() != "Conv") {
continue;
}
if (i + 4 >= node_count) {
continue;
}
onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
// match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant
std::vector<std::tuple<std::string, onnx::NodeProto**>> candidates = {
{"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}};
int MAX = std::min(10, node_count - i - 1);
int pos_candidate = 0;
for (int j = 0; j < MAX; ++j) {
auto node_ptr = mutable_graph->mutable_node(j + i + 1);
if (node_ptr->op_type() == "Constant") {
continue;
}
if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) {
*(std::get<1>(candidates[pos_candidate])) = node_ptr;
pos_candidate++;
}
}
if (pos_candidate != candidates.size()) {
// not match the sequence
continue;
}
if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 ||
node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 ||
node_reference[reshape->output(0)] != 1) {
continue;
}
// check the connections
if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) {
continue;
}
if (slice->input(0) != shape->output(0)) {
continue;
}
if (concat->input(0) != slice->output(0)) {
continue;
}
if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) {
continue;
}
// add reshape attr
auto result = query_shape(mutable_graph, concat, weights, shape_context);
if (!std::get<0>(result)) {
continue;
}
set_node_attr_ai(*reshape, "shape", std::get<1>(result));
// reconstruct graph
{
// remove reference
node_reference[reshape->input(1)] -= 1;
node_reference[concat->input(0)] -= 1;
node_reference[slice->input(0)] -= 1;
node_reference[shape->input(0)] -= 1;
// remove tensor/blob on edge
blob_names.erase(slice->input(0));
blob_names.erase(slice->input(1));
blob_names.erase(slice->input(2));
blob_names.erase(slice->input(3));
weights.erase(slice->input(1));
weights.erase(slice->input(2));
weights.erase(slice->input(3));
blob_names.erase(concat->input(0));
blob_names.erase(concat->input(1));
weights.erase(concat->input(1));
blob_names.erase(reshape->input(0));
// update edge
shape->clear_input();
reshape->clear_input();
reshape->add_input(conv->output(0));
shape->set_op_type("noop_reducedncnn");
slice->set_op_type("noop_reducedncnn");
concat->set_op_type("noop_reducedncnn");
reduced_node_count += 3;
}
i += 3;
}
}
void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Add/Sub/Mul/Div/Min/Max/Pow
if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" ||
node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" ||
node->op_type() == "Pow") {
if (weights.find(node->input(1)) == weights.end()) continue;
const onnx::TensorProto& scalar_b = weights[node->input(1)];
if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue;
float b = get_node_attr_from_input<float>(scalar_b);
node_reference[node->input(1)] -= 1;
std::string input = node->input(0);
node->clear_input();
node->add_input(input);
onnx::AttributeProto* attr_with_scalar = node->add_attribute();
attr_with_scalar->set_name("with_scalar");
attr_with_scalar->set_i(1);
onnx::AttributeProto* attr_b = node->add_attribute();
attr_b->set_name("b");
attr_b->set_f(b);
}
}
}
void fuse_hardswish(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
// HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
// HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
// HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
// out = x * F.relu6(x + 3, inplace=True) / 6
if (node->op_type() == "Add") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 3 >= node_count) continue;
if (weights.find(node->input(1)) == weights.end()) continue;
const onnx::TensorProto& add_three = weights[node->input(1)];
if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
float constant_add_three = get_node_attr_from_input<float>(add_three);
if (constant_add_three != 3.f) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
if (node4->op_type() == "Constant") {
if (i + 4 >= node_count) continue;
node4 = mutable_graph->mutable_node(i + 4);
}
if (node2->op_type() != "Clip" || node3->op_type() != "Mul" ||
(node4->op_type() != "Div" && node4->op_type() != "Mul"))
continue;
if (node_reference[node2->output(0)] != 1) continue;
float relu6_min;
float relu6_max;
if (node2->input_size() == 1) {
relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
} else {
const onnx::TensorProto& min_tp = weights[node2->input(1)];
const onnx::TensorProto& max_tp = weights[node2->input(2)];
relu6_min = get_node_attr_from_input<float>(min_tp);
relu6_max = get_node_attr_from_input<float>(max_tp);
}
if (relu6_min != 0.f || relu6_max != 6.f) continue;
if (node_reference[node3->output(0)] != 1) continue;
if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue;
if (weights.find(node4->input(1)) == weights.end()) continue;
const onnx::TensorProto& div_six = weights[node4->input(1)];
if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
float constant_div_six = get_node_attr_from_input<float>(div_six);
if (node4->op_type() == "Div" && constant_div_six != 6.f) continue;
if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
node_reference[node->input(1)] -= 1;
node_reference[node->output(0)] -= 1;
if (node2->input_size() == 3) {
node_reference[node2->input(1)] -= 1;
node_reference[node2->input(2)] -= 1;
}
node_reference[node2->output(0)] -= 1;
node_reference[node3->output(0)] -= 1;
node_reference[node4->input(1)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
node4->set_op_type("HardSwish");
node4->clear_input();
node4->add_input(node->input(0));
onnx::AttributeProto* attr_alpha = node4->add_attribute();
attr_alpha->set_name("alpha");
attr_alpha->set_f(1.f / 6.f);
onnx::AttributeProto* attr_beta = node4->add_attribute();
attr_beta->set_name("beta");
attr_beta->set_f(3.f / 6.f);
reduced_node_count += 3;
i += 3;
}
}
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// HardSwish <= HardSigmoid - Mul
// out = x * hsigmoid(x)
if (node->op_type() == "HardSigmoid") {
if (node_reference[node->output(0)] != 1) continue;
float alpha = get_node_attr_f(*node, "alpha", 0.2f);
float beta = get_node_attr_f(*node, "beta", 0.5f);
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "Mul") continue;
if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node2->set_op_type("HardSwish");
node2->clear_input();
node2->add_input(node->input(0));
onnx::AttributeProto* attr_alpha = node2->add_attribute();
attr_alpha->set_name("alpha");
attr_alpha->set_f(alpha);
onnx::AttributeProto* attr_beta = node2->add_attribute();
attr_beta->set_name("beta");
attr_beta->set_f(beta);
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
// HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
// HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
// HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
// out = F.relu6(x + 3, inplace=True) / 6
if (node->op_type() == "Add") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 2 >= node_count) continue;
if (weights.find(node->input(1)) == weights.end()) continue;
const onnx::TensorProto& add_three = weights[node->input(1)];
if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
float constant_add_three = get_node_attr_from_input<float>(add_three);
if (constant_add_three != 3.f) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
}
if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
continue;
if (node_reference[node2->output(0)] != 1) continue;
float relu6_min;
float relu6_max;
if (node2->input_size() == 1) {
relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
} else {
const onnx::TensorProto& min_tp = weights[node2->input(1)];
const onnx::TensorProto& max_tp = weights[node2->input(2)];
relu6_min = get_node_attr_from_input<float>(min_tp);
relu6_max = get_node_attr_from_input<float>(max_tp);
}
if (relu6_min != 0.f || relu6_max != 6.f) continue;
if (weights.find(node3->input(1)) == weights.end()) continue;
const onnx::TensorProto& div_six = weights[node3->input(1)];
if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
float constant_div_six = get_node_attr_from_input<float>(div_six);
if (node3->op_type() == "Div" && constant_div_six != 6.f) continue;
if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node_reference[node->input(1)] -= 1;
node_reference[node->output(0)] -= 1;
if (node2->input_size() == 3) {
node_reference[node2->input(1)] -= 1;
node_reference[node2->input(2)] -= 1;
}
node_reference[node2->output(0)] -= 1;
node_reference[node3->input(1)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node3->set_op_type("HardSigmoid");
node3->clear_input();
node3->add_input(node->input(0));
onnx::AttributeProto* attr_alpha = node3->add_attribute();
attr_alpha->set_name("alpha");
attr_alpha->set_f(1.f / 6.f);
onnx::AttributeProto* attr_beta = node3->add_attribute();
attr_beta->set_name("beta");
attr_beta->set_f(3.f / 6.f);
reduced_node_count += 2;
i += 2;
}
}
}
void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Swish <= Sigmoid - Mul
// x * torch.sigmoid(x)
if (node->op_type() == "Sigmoid") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "Mul") continue;
if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node2->set_op_type("Swish");
node2->clear_input();
node2->add_input(node->input(0));
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
if (node->op_type() == "Unsqueeze") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue;
if (node_reference[node2->output(0)] != 1) continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node2->set_input(0, node->input(0));
node2->set_output(0, node3->output(0));
reduced_node_count += 2;
i += 2;
}
}
}
void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// PReLU <= Unsqueeze - PReLU
if (node->op_type() == "Unsqueeze") {
// check weight
if (weights.find(node->input(0)) == weights.end()) continue;
onnx::TensorProto& B = weights[node->input(0)];
if (B.dims_size() != 1) continue;
if (node_reference[node->output(0)] != 1) continue;
// axes = (1, 2)
std::vector<int> axes = get_node_attr_ai(*node, "axes");
if (axes.size() != 2) continue;
if (axes[0] != 1 || axes[1] != 2) continue;
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "PRelu") continue;
if (node2->input(1) != node->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node2->set_input(1, node->input(0));
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_normalize(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Normalize <= X - ReduceL2 - Clip - Expand - Div
// Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
if (node->op_type() == "ReduceL2") {
if (node_reference[node->output(0)] != 1) continue;
// axes = (1)
std::vector<int> axes = get_node_attr_ai(*node, "axes");
if (axes.size() != 1) continue;
if (axes[0] != 1) continue;
if (i + 3 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
bool has_shape_node = node3->op_type() == "Shape";
onnx::NodeProto* node_shape = 0;
if (has_shape_node) {
if (i + 4 >= node_count) continue;
node_shape = node3;
node3 = mutable_graph->mutable_node(i + 3);
node4 = mutable_graph->mutable_node(i + 4);
}
if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
continue;
if (node_reference[node2->output(0)] != 1) continue;
if (node_reference[node3->output(0)] != 1) continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) ||
node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
continue;
if (has_shape_node) {
if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
continue;
}
// +eps
float clip_min;
if (node2->input_size() == 1) {
clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
} else {
const onnx::TensorProto& min_tp = weights[node2->input(1)];
clip_min = get_node_attr_from_input<float>(min_tp);
}
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
if (has_shape_node) {
node_shape->set_op_type("noop_reducedncnn");
}
node3->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
if (has_shape_node) {
node_reference[node_shape->output(0)] -= 1;
}
node_reference[node3->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
if (has_shape_node) {
blob_names.erase(node_shape->output(0));
}
blob_names.erase(node3->output(0));
node4->set_op_type("Normalize");
node4->clear_input();
node4->add_input(node->input(0));
onnx::AttributeProto* attr_alpha = node4->add_attribute();
attr_alpha->set_name("eps");
attr_alpha->set_f(clip_min);
reduced_node_count += has_shape_node ? 4 : 3;
i += has_shape_node ? 4 : 3;
}
}
}
void fuse_groupnorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
if (node->op_type() == "Reshape") {
if (node_reference[node->output(0)] != 1) continue;
std::vector<int> shape;
if (node->input_size() == 1) {
shape = get_node_attr_ai(*node, "shape");
} else {
// skip weight reshape
if (weights.find(node->input(1)) == weights.end()) continue;
shape = get_node_attr_from_input_ai(weights[node->input(1)]);
}
// 0, group, -1
if (shape.size() != 3) continue;
if (shape[0] != 0 || shape[2] != -1) continue;
int groups = shape[1];
if (i + 4 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" ||
node4->op_type() != "Mul" || node5->op_type() != "Add")
continue;
if (node_reference[node2->output(0)] != 1) continue;
if (node_reference[node3->output(0)] != 1) continue;
if (node_reference[node4->output(0)] != 1) continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) ||
node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
continue;
// +eps
float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
// InstanceNormalization S=1 B=0
std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
if ((int)S.size() != groups || (int)B.size() != groups) continue;
bool instancenorm_affine = false;
for (int j = 0; j < groups; j++) {
if (S[j] != 1.f || B[j] != 0.f) {
instancenorm_affine = true;
break;
}
}
if (instancenorm_affine) continue;
std::vector<int> shape2;
if (node3->input_size() == 1) {
shape2 = get_node_attr_ai(*node3, "shape");
} else {
// skip weight reshape
if (weights.find(node3->input(1)) == weights.end()) continue;
shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
}
// 1, channels, w, h
if (shape2.size() != 4) continue;
if (shape2[0] != 1) continue;
int channels = shape2[1];
// affine
std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 &&
affine_B[0] == 0.f) {
// no affine
} else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) {
// we only allow per-channel affine
continue;
}
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
node_reference[node->output(0)] -= 1;
node_reference[node2->input(1)] -= 1;
node_reference[node2->input(2)] -= 1;
node_reference[node2->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
node_reference[node3->output(0)] -= 1;
node_reference[node4->output(0)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
std::string affine_scale = node4->input(1);
std::string affine_bias = node5->input(1);
node5->set_op_type("GroupNorm");
node5->clear_input();
node5->add_input(node->input(0));
node5->add_input(affine_scale);
node5->add_input(affine_bias);
onnx::AttributeProto* attr_groups = node5->add_attribute();
attr_groups->set_name("groups");
attr_groups->set_i(groups);
onnx::AttributeProto* attr_channels = node5->add_attribute();
attr_channels->set_name("channels");
attr_channels->set_i(channels);
onnx::AttributeProto* attr_eps = node5->add_attribute();
attr_eps->set_name("epsilon");
attr_eps->set_f(eps);
onnx::AttributeProto* attr_affine = node5->add_attribute();
attr_affine->set_name("affine");
attr_affine->set_i(1);
reduced_node_count += 4;
i += 4;
}
}
}
void fuse_layernorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
// LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div -
// Mul - Add
if (node->op_type() == "ReduceMean") {
if (node_reference[node->output(0)] != 1) continue;
std::vector<int> axes = get_node_attr_ai(*node, "axes");
// -1
// -2 -1
if (axes.size() != 1 && axes.size() != 2) continue;
int normed_axes = (int)axes.size();
if (normed_axes == 1 && axes[0] != -1) continue;
if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1)) continue;
if (i + 6 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
if (node2->op_type() != "Sub" || node3->op_type() != "Pow" ||
node4->op_type() != "ReduceMean" || node5->op_type() != "Add" ||
node6->op_type() != "Sqrt" || node7->op_type() != "Div")
continue;
if (node_reference[node2->output(0)] != 2) continue;
if (node_reference[node3->output(0)] != 1) continue;
if (node_reference[node4->output(0)] != 1) continue;
if (node_reference[node5->output(0)] != 1) continue;
if (node_reference[node6->output(0)] != 1) continue;
if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0) ||
node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) ||
node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) ||
node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
continue;
if (weights.find(node3->input(1)) == weights.end()) continue;
const onnx::TensorProto& pow_two = weights[node3->input(1)];
if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue;
float constant_pow_two = get_node_attr_from_input<float>(pow_two);
if (constant_pow_two != 2.f) continue;
std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");
// -1
// -2 -1
if ((int)axes4.size() != normed_axes) continue;
if (normed_axes == 1 && axes4[0] != -1) continue;
if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1)) continue;
if (weights.find(node5->input(1)) == weights.end()) continue;
const onnx::TensorProto& add_eps = weights[node5->input(1)];
if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue;
float eps = get_node_attr_from_input<float>(add_eps);
int affine = 0;
while (i + 8 < node_count) {
onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
if (node8->op_type() != "Mul" || node9->op_type() != "Add") break;
if (node_reference[node7->output(0)] != 1) break;
if (node_reference[node8->output(0)] != 1) break;
if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0)) break;
// affine
std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
if (affine_S.size() != affine_B.size()) break;
affine = 1;
break;
}
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");
node6->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
node_reference[node2->input(0)] -= 1;
node_reference[node2->input(1)] -= 1;
node_reference[node3->input(0)] -= 1;
node_reference[node3->input(1)] -= 1;
node_reference[node4->input(0)] -= 1;
node_reference[node5->input(0)] -= 1;
node_reference[node5->input(1)] -= 1;
node_reference[node6->input(0)] -= 1;
node_reference[node7->input(0)] -= 1;
node_reference[node7->input(1)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
blob_names.erase(node5->output(0));
blob_names.erase(node6->output(0));
node_reference[node->input(0)] += 1;
if (affine == 0) {
node7->set_op_type("LayerNorm");
node7->clear_input();
node7->add_input(node->input(0));
onnx::AttributeProto* attr_eps = node7->add_attribute();
attr_eps->set_name("epsilon");
attr_eps->set_f(eps);
onnx::AttributeProto* attr_affine = node7->add_attribute();
attr_affine->set_name("affine");
attr_affine->set_i(affine);
reduced_node_count += 6;
i += 6;
} else // if (affine == 1)
{
onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
node7->set_op_type("noop_reducedncnn");
node8->set_op_type("noop_reducedncnn");
node_reference[node8->input(0)] -= 1;
node_reference[node9->input(0)] -= 1;
blob_names.erase(node7->output(0));
blob_names.erase(node8->output(0));
std::string affine_scale = node8->input(1);
std::string affine_bias = node9->input(1);
node9->set_op_type("LayerNorm");
node9->clear_input();
node9->add_input(node->input(0));
node9->add_input(affine_scale);
node9->add_input(affine_bias);
onnx::AttributeProto* attr_eps = node9->add_attribute();
attr_eps->set_name("epsilon");
attr_eps->set_f(eps);
onnx::AttributeProto* attr_affine = node9->add_attribute();
attr_affine->set_name("affine");
attr_affine->set_i(affine);
reduced_node_count += 8;
i += 8;
}
}
}
}
void fuse_flatten(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat
// - Reshape
if (node->op_type() == "Shape") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 6 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
if (node2->op_type() != "Gather" || node3->op_type() != "Constant" ||
node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" ||
node6->op_type() != "Concat" || node7->op_type() != "Reshape")
continue;
if (node_reference[node2->output(0)] != 1) continue;
// if (node_reference[node3->output(0)] != 1)
// continue;
if (node_reference[node4->output(0)] != 1) continue;
if (node_reference[node5->output(0)] != 1) continue;
if (node_reference[node6->output(0)] != 1) continue;
if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) ||
node5->input(0) != node3->output(0) || node6->input(0) != node4->output(0) ||
node6->input(1) != node5->output(0) || node7->input(0) != node->input(0) ||
node7->input(1) != node6->output(0))
continue;
// axis = 0
int gather_axis = get_node_attr_i(*node2, "axis");
if (gather_axis != 0) continue;
// indices = 0
if (weights.find(node2->input(1)) == weights.end()) continue;
std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
if (gather_indices.size() != 1 || gather_indices[0] != 0) continue;
// axes = (0)
std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
if (unsqueeze_axes.size() != 1) continue;
if (unsqueeze_axes[0] != 0) continue;
// axes = (0)
std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
if (unsqueeze2_axes.size() != 1) continue;
if (unsqueeze2_axes[0] != 0) continue;
// data = -1
if (weights.find(node5->input(0)) == weights.end()) continue;
std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) continue;
// axis = 0
int concat_axis = get_node_attr_i(*node6, "axis");
if (concat_axis != 0) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
// node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");
node6->set_op_type("noop_reducedncnn");
node_reference[node->input(0)] -= 1;
node_reference[node->output(0)] -= 1;
node_reference[node2->input(1)] -= 1;
node_reference[node2->output(0)] -= 1;
// node_reference[node3->output(0)] -= 1;
node_reference[node4->output(0)] -= 1;
node_reference[node5->input(0)] -= 1;
node_reference[node5->output(0)] -= 1;
node_reference[node6->output(0)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
// blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
blob_names.erase(node5->output(0));
blob_names.erase(node6->output(0));
node7->set_op_type("Flatten");
node7->clear_input();
node7->add_input(node->input(0));
reduced_node_count += 5;
i += 5;
}
}
}
void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// PixelShuffle <= Reshape - Transpose - Reshape
// PixelShuffle <= Reshape - Transpose - Constant - Reshape
if (node->op_type() == "Reshape") {
if (node_reference[node->output(0)] != 1) continue;
std::vector<int> shape;
if (node->input_size() == 1) {
shape = get_node_attr_ai(*node, "shape");
} else {
// skip weight reshape
if (weights.find(node->input(1)) == weights.end()) continue;
shape = get_node_attr_from_input_ai(weights[node->input(1)]);
}
// -1, 3, upscale_factor, upscale_factor, height, width
if (shape.size() != 6) continue;
if (shape[0] != 1 && shape[0] != -1) continue;
if (shape[2] != shape[3]) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
}
if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;
if (node_reference[node2->output(0)] != 1) continue;
// 0 1 4 2 5 3
std::vector<int> perm = get_node_attr_ai(*node2, "perm");
if (perm.size() != 6) continue;
if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 ||
perm[5] != 3)
continue;
std::vector<int> shape3;
if (node3->input_size() == 1) {
shape3 = get_node_attr_ai(*node3, "shape");
} else {
// skip weight reshape
if (weights.find(node3->input(1)) == weights.end()) continue;
shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
}
// -1, 3, height, width
if (shape3.size() != 4) continue;
if (shape3[0] != 1 && shape3[0] != -1) continue;
if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] ||
shape3[3] != shape[3] * shape[5])
continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node3->set_op_type("PixelShuffle");
node3->set_input(0, node->input(0));
onnx::AttributeProto* attr_group = node3->add_attribute();
attr_group->set_name("scale_factor");
attr_group->set_i(shape[2]);
reduced_node_count += 2;
i += 2;
}
}
}
void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// PixelShuffle <= Reshape - Transpose - Reshape
// PixelShuffle <= Reshape - Transpose - Constant - Reshape
if (node->op_type() == "Reshape") {
if (node_reference[node->output(0)] != 1) continue;
std::vector<int> shape;
if (node->input_size() == 1) {
shape = get_node_attr_ai(*node, "shape");
} else {
// skip weight reshape
if (weights.find(node->input(1)) == weights.end()) continue;
shape = get_node_attr_from_input_ai(weights[node->input(1)]);
}
// -1, 3, out_height, block_size, out_width, block_size
if (shape.size() != 6) continue;
if (shape[0] != 1 && shape[0] != -1) continue;
if (shape[3] != shape[5]) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
}
if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;
if (node_reference[node2->output(0)] != 1) continue;
// 0 1 3 5 2 4
std::vector<int> perm = get_node_attr_ai(*node2, "perm");
if (perm.size() != 6) continue;
if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 ||
perm[5] != 4)
continue;
std::vector<int> shape3;
if (node3->input_size() == 1) {
shape3 = get_node_attr_ai(*node3, "shape");
} else {
// skip weight reshape
if (weights.find(node3->input(1)) == weights.end()) continue;
shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
}
// -1, out_channels, out_height, out_width
if (shape3.size() != 4) continue;
if (shape3[0] != 1 && shape3[0] != -1) continue;
if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] ||
shape3[3] != shape[4])
continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node3->set_op_type("Reorg");
node3->set_input(0, node->input(0));
onnx::AttributeProto* attr_group = node3->add_attribute();
attr_group->set_name("stride");
attr_group->set_i(shape[3]);
reduced_node_count += 2;
i += 2;
}
}
}
void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
if (node->op_type() == "Expand") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" &&
node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
continue;
if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
if (node->input_size() == 2) {
node_reference[node->input(1)] -= 1;
}
blob_names.erase(node->output(0));
if (node2->input(0) == node->output(0)) {
node2->set_input(0, node->input(0));
} else {
node2->set_input(1, node->input(0));
}
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
// or LSTM(bi) <= LSTM(bi) - Transpose Constant - Reshape - Transpose
if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 2 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
// skip if second ops is constant
if (node3->op_type() == "Constant") {
if (i + 3 >= node_count) continue;
node3 = mutable_graph->mutable_node(i + 3);
i += 1;
}
if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;
if (node_reference[node2->output(0)] != 1) continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue;
std::string direction = get_node_attr_s(*node, "direction");
if (direction != "bidirectional") continue;
// 0 2 1 3
std::vector<int> perm = get_node_attr_ai(*node2, "perm");
if (perm.size() != 4) continue;
if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3) continue;
std::vector<int> shape;
if (node3->input_size() == 1) {
shape = get_node_attr_ai(*node3, "shape");
} else {
// skip weight reshape
if (weights.find(node3->input(1)) == weights.end()) continue;
shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
}
// 0 0 -1
if (shape.size() != 3) continue;
if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1) continue;
// reduce
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
if (node3->input_size() == 2) {
node_reference[node3->input(1)] -= 1;
}
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
node->set_output(0, node3->output(0));
reduced_node_count += 2;
i += 2;
if (i + 1 < node_count) {
if (node_reference[node3->output(0)] != 1) continue;
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
if (node4->op_type() != "Transpose") continue;
if (node4->input(0) != node->output(0)) continue;
// 1 0 2
std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
if (perm4.size() != 3) continue;
if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue;
// reduce
node4->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node->set_output(0, node4->output(0));
reduced_node_count += 1;
i += 1;
}
}
}
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") {
if (node_reference[node->output(0)] != 1) continue;
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "Squeeze") continue;
if (node2->input(0) != node->output(0)) continue;
std::string direction = get_node_attr_s(*node, "direction");
if (direction == "bidirectional") continue;
// 1
std::vector<int> axes = get_node_attr_ai(*node2, "axes");
if (axes.size() != 1) continue;
if (axes[0] != 1) continue;
// reduce
node2->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node->set_output(0, node2->output(0));
reduced_node_count += 1;
i += 1;
if (i + 1 < node_count) {
if (node_reference[node2->output(0)] != 1) continue;
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
if (node3->op_type() != "Transpose") continue;
if (node3->input(0) != node->output(0)) continue;
// 1 0 2
std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
if (perm4.size() != 3) continue;
if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue;
// reduce
node3->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node->set_output(0, node3->output(0));
reduced_node_count += 1;
i += 1;
}
}
}
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// LSTM <= Transpose - LSTM
if (node->op_type() == "Transpose") {
if (node_reference[node->output(0)] != 1) continue;
// 1 0 2
std::vector<int> perm = get_node_attr_ai(*node, "perm");
if (perm.size() != 3) continue;
if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2) continue;
if (i + 1 >= node_count) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
continue;
if (node2->input(0) != node->output(0)) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node_reference[node->output(0)] -= 1;
blob_names.erase(node->output(0));
node2->set_input(0, node->input(0));
reduced_node_count += 1;
i += 1;
}
}
}
void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// MultiHeadAttention <= MatMul(q) - Add
// - MatMul(k) - Add
// - MatMul(v) - Add
// - Mul
// - Reshape - Transpose
// - Reshape - Reshape - Transpose - Transpose
// - Gemm - Softmax - Gemm - Transpose - Reshape -
// MatMul - Add
if (node->op_type() == "MatMul") {
if (i + 19 >= node_count) continue;
if (node_reference[node->output(0)] != 1) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);
if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" ||
node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" ||
node8->op_type() != "Reshape" || node9->op_type() != "Transpose" ||
node10->op_type() != "Reshape" || node11->op_type() != "Reshape" ||
node12->op_type() != "Transpose" || node13->op_type() != "Transpose" ||
node14->op_type() != "MatMul" || node15->op_type() != "Softmax" ||
node16->op_type() != "MatMul" || node17->op_type() != "Transpose" ||
node18->op_type() != "Reshape" || node19->op_type() != "MatMul" ||
node20->op_type() != "Add")
continue;
if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 ||
node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 ||
node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 ||
node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 ||
node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 ||
node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 ||
node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 ||
node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 ||
node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
continue;
if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) ||
node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) ||
node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) ||
node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) ||
node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) ||
node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) ||
node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) ||
node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) ||
node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) ||
node20->input(0) != node19->output(0))
continue;
std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);
if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
continue;
int embed_dim = q_B.size();
// 1 0 2
std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
if (perm9.size() != 3 || perm12.size() != 3) continue;
if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 ||
perm12[2] != 2)
continue;
// 1 2 0
std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
if (perm13.size() != 3) continue;
if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0) continue;
// 1 0 2
std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
if (perm17.size() != 3) continue;
if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2) continue;
int softmax_axis = get_node_attr_i(*node15, "axis");
if (softmax_axis != 2) continue;
// 1/-1, seqlen * num_heads, embed_dim / num_heads
std::vector<int> shape8;
std::vector<int> shape10;
std::vector<int> shape11;
if (node8->input_size() == 1) {
shape8 = get_node_attr_ai(*node8, "shape");
} else {
// skip weight reshape
if (weights.find(node8->input(1)) == weights.end()) continue;
shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
}
if (node10->input_size() == 1) {
shape10 = get_node_attr_ai(*node10, "shape");
} else {
// skip weight reshape
if (weights.find(node10->input(1)) == weights.end()) continue;
shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
}
if (node11->input_size() == 1) {
shape11 = get_node_attr_ai(*node11, "shape");
} else {
// skip weight reshape
if (weights.find(node11->input(1)) == weights.end()) continue;
shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
}
if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3) continue;
if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] ||
shape8[2] != shape11[2])
continue;
int num_heads = embed_dim / shape8[2];
// 1, seqlen, embed_dim
std::vector<int> shape18;
if (node18->input_size() == 1) {
shape18 = get_node_attr_ai(*node18, "shape");
} else {
// skip weight reshape
if (weights.find(node18->input(1)) == weights.end()) continue;
shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
}
if (shape18.size() != 3) continue;
if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1]) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");
node6->set_op_type("noop_reducedncnn");
node7->set_op_type("noop_reducedncnn");
node8->set_op_type("noop_reducedncnn");
node9->set_op_type("noop_reducedncnn");
node10->set_op_type("noop_reducedncnn");
node11->set_op_type("noop_reducedncnn");
node12->set_op_type("noop_reducedncnn");
node13->set_op_type("noop_reducedncnn");
node14->set_op_type("noop_reducedncnn");
node15->set_op_type("noop_reducedncnn");
node16->set_op_type("noop_reducedncnn");
node17->set_op_type("noop_reducedncnn");
node18->set_op_type("noop_reducedncnn");
node19->set_op_type("noop_reducedncnn");
node_reference[node2->input(0)] -= 1;
node_reference[node4->input(0)] -= 1;
node_reference[node6->input(0)] -= 1;
node_reference[node7->input(0)] -= 1;
node_reference[node7->input(1)] -= 1;
node_reference[node8->input(0)] -= 1;
if (node8->input_size() == 2) {
node_reference[node8->input(1)] -= 1;
}
node_reference[node9->input(0)] -= 1;
node_reference[node10->input(0)] -= 1;
if (node10->input_size() == 2) {
node_reference[node10->input(1)] -= 1;
}
node_reference[node11->input(0)] -= 1;
if (node11->input_size() == 2) {
node_reference[node11->input(1)] -= 1;
}
node_reference[node12->input(0)] -= 1;
node_reference[node13->input(0)] -= 1;
node_reference[node14->input(0)] -= 1;
node_reference[node14->input(1)] -= 1;
node_reference[node15->input(0)] -= 1;
node_reference[node16->input(0)] -= 1;
node_reference[node16->input(1)] -= 1;
node_reference[node17->input(0)] -= 1;
node_reference[node18->input(0)] -= 1;
if (node18->input_size() == 2) {
node_reference[node18->input(1)] -= 1;
}
node_reference[node19->input(0)] -= 1;
node_reference[node20->input(0)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
blob_names.erase(node5->output(0));
blob_names.erase(node6->output(0));
blob_names.erase(node7->output(0));
blob_names.erase(node8->output(0));
blob_names.erase(node9->output(0));
blob_names.erase(node10->output(0));
blob_names.erase(node11->output(0));
blob_names.erase(node12->output(0));
blob_names.erase(node13->output(0));
blob_names.erase(node14->output(0));
blob_names.erase(node15->output(0));
blob_names.erase(node16->output(0));
blob_names.erase(node17->output(0));
blob_names.erase(node18->output(0));
blob_names.erase(node19->output(0));
std::string qw = node->input(1);
std::string qb = node2->input(1);
std::string kw = node3->input(1);
std::string kb = node4->input(1);
std::string vw = node5->input(1);
std::string vb = node6->input(1);
std::string ow = node19->input(1);
std::string ob = node20->input(1);
node20->set_op_type("MultiHeadAttention");
node20->clear_input();
node20->add_input(node->input(0));
node20->add_input(node3->input(0));
node20->add_input(node5->input(0));
// q
node20->add_input(qw);
node20->add_input(qb);
// k
node20->add_input(kw);
node20->add_input(kb);
// v
node20->add_input(vw);
node20->add_input(vb);
// out linear
node20->add_input(ow);
node20->add_input(ob);
onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
attr_embed_dim->set_name("embed_dim");
attr_embed_dim->set_i(embed_dim);
onnx::AttributeProto* attr_num_heads = node20->add_attribute();
attr_num_heads->set_name("num_heads");
attr_num_heads->set_i(num_heads);
reduced_node_count += 19;
i += 19;
}
}
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
// MultiHeadAttention <= MatMul(qkv) - Add - Split
// - Mul
// - Reshape - Transpose
// - Reshape - Reshape - Transpose - Transpose
// - Gemm - Softmax - Gemm - Transpose - Reshape -
// MatMul - Add
if (node->op_type() == "MatMul") {
if (i + 16 >= node_count) continue;
if (node_reference[node->output(0)] != 1) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" ||
node5->op_type() != "Reshape" || node6->op_type() != "Transpose" ||
node7->op_type() != "Reshape" || node8->op_type() != "Reshape" ||
node9->op_type() != "Transpose" || node10->op_type() != "Transpose" ||
node11->op_type() != "MatMul" || node12->op_type() != "Softmax" ||
node13->op_type() != "MatMul" || node14->op_type() != "Transpose" ||
node15->op_type() != "Reshape" || node16->op_type() != "MatMul" ||
node17->op_type() != "Add")
continue;
if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 ||
node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 ||
node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 ||
node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 ||
node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 ||
node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 ||
node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 ||
node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 ||
node_reference[node16->output(0)] != 1)
continue;
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) ||
node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) ||
node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) ||
node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) ||
node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) ||
node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) ||
node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) ||
node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) ||
node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
continue;
std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
if (qkv_B.size() != o_B.size() * 3) continue;
int embed_dim = o_B.size();
// 1 0 2
std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
if (perm6.size() != 3 || perm9.size() != 3) continue;
if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 ||
perm9[2] != 2)
continue;
// 1 2 0
std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
if (perm10.size() != 3) continue;
if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue;
// 1 0 2
std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
if (perm14.size() != 3) continue;
if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue;
int softmax_axis = get_node_attr_i(*node12, "axis");
if (softmax_axis != 2) continue;
// 1/-1, seqlen * num_heads, embed_dim / num_heads
std::vector<int> shape5;
std::vector<int> shape7;
std::vector<int> shape8;
if (node5->input_size() == 1) {
shape5 = get_node_attr_ai(*node5, "shape");
} else {
// skip weight reshape
if (weights.find(node5->input(1)) == weights.end()) continue;
shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
}
if (node7->input_size() == 1) {
shape7 = get_node_attr_ai(*node7, "shape");
} else {
// skip weight reshape
if (weights.find(node7->input(1)) == weights.end()) continue;
shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
}
if (node8->input_size() == 1) {
shape8 = get_node_attr_ai(*node8, "shape");
} else {
// skip weight reshape
if (weights.find(node8->input(1)) == weights.end()) continue;
shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
}
if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue;
if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] ||
shape5[2] != shape8[2])
continue;
int num_heads = embed_dim / shape5[2];
// 1, seqlen, embed_dim
std::vector<int> shape15;
if (node15->input_size() == 1) {
shape15 = get_node_attr_ai(*node15, "shape");
} else {
// skip weight reshape
if (weights.find(node15->input(1)) == weights.end()) continue;
shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
}
if (shape15.size() != 3) continue;
if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue;
// reduce
node->set_op_type("noop_reducedncnn");
node2->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");
node6->set_op_type("noop_reducedncnn");
node7->set_op_type("noop_reducedncnn");
node8->set_op_type("noop_reducedncnn");
node9->set_op_type("noop_reducedncnn");
node10->set_op_type("noop_reducedncnn");
node11->set_op_type("noop_reducedncnn");
node12->set_op_type("noop_reducedncnn");
node13->set_op_type("noop_reducedncnn");
node14->set_op_type("noop_reducedncnn");
node15->set_op_type("noop_reducedncnn");
node16->set_op_type("noop_reducedncnn");
node_reference[node2->input(0)] -= 1;
node_reference[node3->input(0)] -= 1;
node_reference[node4->input(0)] -= 1;
node_reference[node4->input(1)] -= 1;
node_reference[node5->input(0)] -= 1;
if (node5->input_size() == 2) {
node_reference[node5->input(1)] -= 1;
}
node_reference[node6->input(0)] -= 1;
node_reference[node7->input(0)] -= 1;
if (node7->input_size() == 2) {
node_reference[node7->input(1)] -= 1;
}
node_reference[node8->input(0)] -= 1;
if (node8->input_size() == 2) {
node_reference[node8->input(1)] -= 1;
}
node_reference[node9->input(0)] -= 1;
node_reference[node10->input(0)] -= 1;
node_reference[node11->input(0)] -= 1;
node_reference[node11->input(1)] -= 1;
node_reference[node12->input(0)] -= 1;
node_reference[node13->input(0)] -= 1;
node_reference[node13->input(1)] -= 1;
node_reference[node14->input(0)] -= 1;
node_reference[node15->input(0)] -= 1;
if (node15->input_size() == 2) {
node_reference[node15->input(1)] -= 1;
}
node_reference[node16->input(0)] -= 1;
node_reference[node17->input(0)] -= 1;
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
blob_names.erase(node3->output(1));
blob_names.erase(node3->output(2));
blob_names.erase(node4->output(0));
blob_names.erase(node5->output(0));
blob_names.erase(node6->output(0));
blob_names.erase(node7->output(0));
blob_names.erase(node8->output(0));
blob_names.erase(node9->output(0));
blob_names.erase(node10->output(0));
blob_names.erase(node11->output(0));
blob_names.erase(node12->output(0));
blob_names.erase(node13->output(0));
blob_names.erase(node14->output(0));
blob_names.erase(node15->output(0));
blob_names.erase(node16->output(0));
std::string qkvw = node->input(1);
std::string qkvb = node2->input(1);
std::string ow = node16->input(1);
std::string ob = node17->input(1);
node17->set_op_type("MultiHeadAttention");
node17->clear_input();
node17->add_input(node->input(0));
// qkv
node17->add_input(qkvw);
node17->add_input(qkvb);
// out linear
node17->add_input(ow);
node17->add_input(ob);
onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
attr_embed_dim->set_name("embed_dim");
attr_embed_dim->set_i(embed_dim);
onnx::AttributeProto* attr_num_heads = node17->add_attribute();
attr_num_heads->set_name("num_heads");
attr_num_heads->set_i(num_heads);
reduced_node_count += 16;
i += 16;
}
}
}
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "shape_inference.h"
#include "utils.h"
void fuse_identity(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
/**
* @brief fuse subgraph
*
* conv - - - - - - - - - - - -> reshape
* \ /
* shape - slice - concat
*
* to
*
* conv --> reshape
*
* @param mutable_graph
* @param weights
* @param node_reference
* @param blob_names
* @param reduced_node_count
*/
void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_hardswish(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_normalize(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_groupnorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_layernorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_flatten(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// Copyright (c) ONNX Project Contributors.
// Licensed under the MIT license.
syntax = "proto2";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION = 0x0000000000000005;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
optional string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
optional string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field hueristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accomodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
optional float f = 2; // float
optional int64 i = 3; // int
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
// This field MUST be present in this version of the IR.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in ths version of the IR.
optional string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
optional string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
optional string domain = 7; // namespace Domain
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
optional int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
optional string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
optional string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
optional string domain = 4;
// The version of the graph encoded. See Version enum below.
optional int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
optional string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
optional GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
optional string value= 2;
};
message TensorAnnotation {
optional string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that
// MAY also appear in the input list.
repeated TensorProto initializer = 5;
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions.
// repeated string input = 3;
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
optional int64 begin = 1;
optional int64 end = 2;
}
optional Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
optional string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
optional string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
optional DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
optional string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
optional string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
}
// Tencent is pleased to support the open source community by making ncnn
// available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except in compliance with the License. You may obtain a copy of the
// License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <set>
#include <tuple>
#include "fuse_pass.h"
#include "shape_inference.h"
#include "utils.h"
int main(int argc, char** argv) {
if (!(argc == 2 || argc == 4)) {
fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
return -1;
}
const char* onnxpb = argv[1];
const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
onnx::ModelProto model;
// load
bool s1 = read_proto_from_binary(onnxpb, &model);
if (!s1) {
fprintf(stderr, "read_proto_from_binary failed\n");
return -1;
}
FILE* pp = fopen(ncnn_prototxt, "wb");
FILE* bp = fopen(ncnn_modelbin, "wb");
// magic
fprintf(pp, "7767517\n");
onnx::GraphProto* mutable_graph = model.mutable_graph();
int node_count = mutable_graph->node_size();
// node reference
std::map<std::string, int> node_reference;
// weight node and weight reshape node
std::map<std::string, onnx::TensorProto> weights;
for (int j = 0; j < mutable_graph->initializer_size(); j++) {
const onnx::TensorProto& initializer = mutable_graph->initializer(j);
// fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(),
// initializer.data_type());
weights[initializer.name()] = initializer;
}
// topological sort
{
// name -> producer node index
std::set<std::string> producers;
for (int j = 0; j < mutable_graph->input_size(); j++) {
const std::string& input_name = mutable_graph->input(j).name();
producers.insert(input_name);
}
for (int i = 0; i < node_count;) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
bool swapnode = false;
std::string missing_input_name;
for (int j = 0; j < (int)node->input_size(); j++) {
const std::string& input_name = node->input(j);
if (input_name.empty()) continue;
if (producers.find(input_name) == producers.end() &&
weights.find(input_name) == weights.end()) {
swapnode = true;
missing_input_name = input_name;
break;
}
}
if (!swapnode) {
for (int j = 0; j < (int)node->output_size(); j++) {
const std::string& output_name = node->output(j);
if (output_name.empty()) continue;
producers.insert(output_name);
}
i++;
continue;
}
// find node that produce missing_input_name
int q = i + 1;
for (; q < node_count; q++) {
onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
bool found = false;
for (int j = 0; j < (int)nodeq->output_size(); j++) {
const std::string& output_name = nodeq->output(j);
if (output_name == missing_input_name) {
found = true;
break;
}
}
if (found) break;
}
if (q == node_count) {
fprintf(stderr, "cannot find node produces %s but node %d requires it\n",
missing_input_name.c_str(), i);
return -1;
}
// fprintf(stderr, "swap %d %d\n", i, q);
// swap this node with q
onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
onnx::NodeProto tmp = *node;
*node = *nodeq;
*nodeq = tmp;
}
}
// global definition line
// [layer count] [blob count]
std::set<std::string> blob_names;
for (int i = 0; i < node_count; i++) {
const onnx::NodeProto& node = mutable_graph->node(i);
const std::string& op = node.op_type();
std::string name = node.name();
if (name.empty()) {
name = node.output(0);
}
if (op == "Constant") {
onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
weights[node.output(0)] = tensor;
}
for (int j = 0; j < (int)node.input_size(); j++) {
const std::string& input_name = node.input(j);
blob_names.insert(input_name);
if (node_reference.find(input_name) == node_reference.end()) {
node_reference[input_name] = 1;
} else {
node_reference[input_name] = node_reference[input_name] + 1;
}
}
if (op == "Dropout") {
const std::string& output_name = node.output(0);
blob_names.insert(output_name);
node_reference[output_name] = 0;
continue;
}
for (int j = 0; j < (int)node.output_size(); j++) {
const std::string& output_name = node.output(j);
blob_names.insert(output_name);
node_reference[output_name] = 0;
}
}
// include Input node
int input_node_count = 0;
for (int j = 0; j < mutable_graph->input_size(); j++) {
const std::string& input_name = mutable_graph->input(j).name();
// check weight
if (weights.find(input_name) != weights.end()) continue;
blob_names.insert(input_name);
input_node_count++;
}
// for (auto a: node_reference)
// {
// fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
// }
// op chain fusion
int reduced_node_count = 0;
{
fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
}
// reduce common const weight node_reference
for (int i = 0; i < node_count; i++) {
const onnx::NodeProto& node = mutable_graph->node(i);
const std::string& op = node.op_type();
if (op == "BatchNormalization") {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
node_reference[node.input(3)] -= 1;
node_reference[node.input(4)] -= 1;
} else if (op == "BiasGelu") {
node_reference[node.input(1)] -= 1;
} else if (op == "Clip") {
if (node.input_size() == 3) {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
}
} else if (op == "Conv") {
node_reference[node.input(1)] -= 1;
if (node.input_size() == 3) {
node_reference[node.input(2)] -= 1;
}
} else if (op == "ConvTranspose") {
node_reference[node.input(1)] -= 1;
if (node.input_size() == 3) {
node_reference[node.input(2)] -= 1;
}
} else if (op == "EmbedLayerNormalization") {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
node_reference[node.input(3)] -= 1;
node_reference[node.input(4)] -= 1;
node_reference[node.input(5)] -= 1;
node_reference[node.input(6)] -= 1;
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
int transA = get_node_attr_i(node, "transA", 0);
int transB = get_node_attr_i(node, "transB", 0);
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
// InnerProduct-like A * B + C, C is optional.
node_reference[node.input(1)] -= 1;
if (node.input_size() == 3) {
node_reference[node.input(2)] -= 1;
}
}
} else if (op == "GroupNorm") {
int affine = get_node_attr_i(node, "affine", 1);
if (affine) {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
}
} else if (op == "GRU") {
for (int j = 1; j < node.input_size(); j++) {
node_reference[node.input(j)] -= 1;
}
} else if (op == "InstanceNormalization") {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
} else if (op == "LayerNorm") {
int affine = get_node_attr_i(node, "affine", 1);
if (affine) {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
}
} else if (op == "LSTM") {
for (int j = 1; j < node.input_size(); j++) {
node_reference[node.input(j)] -= 1;
}
} else if (op == "MatMul") {
if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) {
// InnerProduct
node_reference[node.input(1)] -= 1;
}
} else if (op == "MultiHeadAttention") {
if (node.input_size() == 5) {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
node_reference[node.input(3)] -= 1;
node_reference[node.input(4)] -= 1;
} else {
node_reference[node.input(3)] -= 1;
node_reference[node.input(4)] -= 1;
node_reference[node.input(5)] -= 1;
node_reference[node.input(6)] -= 1;
node_reference[node.input(7)] -= 1;
node_reference[node.input(8)] -= 1;
node_reference[node.input(9)] -= 1;
node_reference[node.input(10)] -= 1;
}
} else if (op == "NonMaxSuppression") {
if (node.input_size() >= 3) {
node_reference[node.input(2)] -= 1;
}
if (node.input_size() >= 4) {
node_reference[node.input(3)] -= 1;
}
if (node.input_size() >= 5) {
node_reference[node.input(4)] -= 1;
}
} else if (op == "Pad") {
if (node.input_size() >= 2) {
node_reference[node.input(1)] -= 1;
}
} else if (op == "PRelu") {
node_reference[node.input(1)] -= 1;
} else if (op == "Reshape") {
if (node.input_size() == 2) {
if (weights[node.input(1)].data_type() != 0) {
node_reference[node.input(1)] -= 1;
}
}
} else if (op == "Resize") {
if (node.input_size() == 2) {
// opset 10
node_reference[node.input(1)] -= 1;
} else {
// opset 11+
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
if (node.input_size() >= 4) {
node_reference[node.input(3)] -= 1;
}
}
} else if (op == "RNN") {
for (int j = 1; j < node.input_size(); j++) {
node_reference[node.input(j)] -= 1;
}
} else if (op == "SkipLayerNormalization") {
node_reference[node.input(2)] -= 1;
node_reference[node.input(3)] -= 1;
node_reference[node.input(4)] -= 1;
} else if (op == "Slice") {
if (node.input_size() >= 2) {
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
if (node.input_size() >= 4) node_reference[node.input(3)] -= 1;
if (node.input_size() >= 5) node_reference[node.input(4)] -= 1;
}
} else if (op == "Upsample") {
if (node.input_size() >= 2) {
node_reference[node.input(1)] -= 1;
}
} else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" ||
op == "adaptive_max_pool2d") {
if (node.input_size() >= 2) {
node_reference[node.input(1)] -= 1;
}
}
}
// for (auto a: node_reference)
// {
// fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
// }
// count all weight node with zero reference
int zero_reference_weight_node_count = 0;
for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end();
it++) {
const std::string& input_name = it->first;
int refcount = node_reference[input_name];
if (refcount == 0) zero_reference_weight_node_count++;
}
// we always treat constant node as weight or binaryop_weights
// do not count it twice for layer_count
int constant_node_count_moved_to_weight = 0;
for (int i = 0; i < node_count; i++) {
const onnx::NodeProto& node = mutable_graph->node(i);
const std::string& op = node.op_type();
if (op == "Constant") {
constant_node_count_moved_to_weight++;
}
}
// some op may have anonymous input
// LSTM sequence_lens
blob_names.erase("");
node_reference.erase("");
// remove node_reference entry with reference equals to one
int split_layer_count = 0;
int splitncnn_blob_count = 0;
// split node reference
std::map<std::string, int> split_node_reference;
for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end();
it++) {
if (it->second > 1) {
split_layer_count++;
splitncnn_blob_count += it->second;
split_node_reference[it->first] = it->second;
}
}
fprintf(pp, "%zu %zu\n",
node_count - constant_node_count_moved_to_weight + weights.size() -
zero_reference_weight_node_count - reduced_node_count + input_node_count +
split_layer_count,
blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
int internal_split = 0;
// place Input at the beginning
for (int j = 0; j < mutable_graph->input_size(); j++) {
const std::string& input_name = mutable_graph->input(j).name();
// check weight
if (weights.find(input_name) != weights.end()) continue;
fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
int refcount = node_reference[input_name];
if (refcount <= 1) {
continue;
}
char splitname[256];
sprintf(splitname, "splitncnn_input%d", j);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
fprintf(pp, " %s", input_name.c_str());
for (int k = 0; k < refcount; k++) {
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
}
fprintf(pp, "\n");
}
// place MemoryData next
for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin();
weight_it != weights.end(); weight_it++) {
const std::string& input_name = weight_it->first;
int refcount = node_reference[input_name];
if (refcount == 0) {
continue;
}
fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
const onnx::TensorProto& M = weights[input_name];
if (M.dims_size() == 0) {
fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
} else if (M.dims_size() == 1) {
fprintf(pp, " 0=%d", (int)M.dims(0));
} else if (M.dims_size() == 2) {
fprintf(pp, " 0=%d", (int)M.dims(1));
if (M.dims(0) != 1) {
fprintf(pp, " 1=%d", (int)M.dims(0));
}
} else if (M.dims_size() == 3) {
fprintf(pp, " 0=%d", (int)M.dims(2));
fprintf(pp, " 1=%d", (int)M.dims(1));
if (M.dims(0) != 1) {
fprintf(pp, " 2=%d", (int)M.dims(0));
}
} else if (M.dims_size() == 4) {
fprintf(pp, " 0=%d", (int)M.dims(3));
fprintf(pp, " 1=%d", (int)M.dims(2));
fprintf(pp, " 2=%d", (int)M.dims(1));
}
fprintf(pp, "\n");
if (M.data_type() == 1) {
fwrite_tensor_proto_data(M, bp);
} else if (M.data_type() == 7 || M.data_type() == 6 || M.data_type() == 9 ||
M.data_type() == 11) {
fwrite_tensor_proto_data_to_float(M, bp);
} else {
fwrite_tensor_proto_data(M, bp);
}
if (refcount <= 1) {
continue;
}
char splitname[256];
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
fprintf(pp, " %s", input_name.c_str());
for (int k = 0; k < refcount; k++) {
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
}
fprintf(pp, "\n");
internal_split++;
}
for (int i = 0; i < node_count; i++) {
const onnx::NodeProto& node = mutable_graph->node(i);
const std::string& op = node.op_type();
// fprintf(stderr, "op = %s\n", op.c_str());
if (op == "noop_reducedncnn") {
continue;
}
std::string name = node.name();
if (name.empty()) {
name = node.output(0);
}
int input_size = node.input_size();
int output_size = node.output_size();
for (int j = 0; j < (int)node.input_size(); j++) {
const std::string& input_name = node.input(j);
// check weight
if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) {
input_size--;
}
if (input_name.empty()) {
input_size--;
}
// fprintf(stderr, " input = %s\n", input_name.c_str());
}
/*
for (int j=0; j<(int)node.output_size(); j++)
{
const std::string& output_name = node.output(j);
fprintf(stderr, " output = %s\n", output_name.c_str());
}
*/
if (op == "Abs") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Acos") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Add") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "ArgMax") {
fprintf(pp, "%-16s", "TopK");
} else if (op == "Asin") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Atan") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "AveragePool" || op == "MaxPool") {
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
if (kernel_shape.size() == 1) {
fprintf(pp, "%-16s", "Pooling1D");
} else {
fprintf(pp, "%-16s", "Pooling");
}
} else if (op == "BatchNormalization") {
fprintf(pp, "%-16s", "BatchNorm");
} else if (op == "BiasGelu") {
fprintf(pp, "%-16s", "BiasGelu");
} else if (op == "Cast") {
fprintf(pp, "%-16s", "Noop");
} else if (op == "Ceil") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Clip") {
fprintf(pp, "%-16s", "Clip");
} else if (op == "Concat") {
fprintf(pp, "%-16s", "Concat");
} else if (op == "Constant") {
continue;
} else if (op == "ConstantOfShape") {
fprintf(pp, "%-16s", "ConstantOfShape");
} else if (op == "Conv") {
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
if (kernel_shape.size() == 1) {
fprintf(pp, "%-16s", "Convolution1D");
} else {
int group = get_node_attr_i(node, "group", 1);
if (group > 1) {
fprintf(pp, "%-16s", "ConvolutionDepthWise");
} else {
fprintf(pp, "%-16s", "Convolution");
}
}
} else if (op == "ConvTranspose") {
int group = get_node_attr_i(node, "group", 1);
if (group > 1) {
fprintf(pp, "%-16s", "DeconvolutionDepthWise");
} else {
fprintf(pp, "%-16s", "Deconvolution");
}
} else if (op == "Cos") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Crop") {
fprintf(pp, "%-16s", "Crop");
} else if (op == "DepthToSpace") {
fprintf(pp, "%-16s", "PixelShuffle");
} else if (op == "DetectionOutput") {
fprintf(pp, "%-16s", "DetectionOutput");
} else if (op == "Div") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "Dropout") {
fprintf(pp, "%-16s", "Dropout");
output_size = 1;
} else if (op == "Elu") {
fprintf(pp, "%-16s", "ELU");
} else if (op == "EmbedLayerNormalization") {
fprintf(pp, "%-16s", "EmbedLayerNormalization");
} else if (op == "Equal") {
fprintf(pp, "%-16s", "Compare");
} else if (op == "Exp") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Expand") {
fprintf(pp, "%-16s", "Expand");
} else if (op == "Flatten") {
fprintf(pp, "%-16s", "Flatten");
} else if (op == "Floor") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Gather") {
fprintf(pp, "%-16s", "Gather");
} else if (op == "Gelu") {
fprintf(pp, "%-16s", "GELU");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
int transA = get_node_attr_i(node, "transA", 0);
int transB = get_node_attr_i(node, "transB", 0);
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
// InnerProduct-like A * B + C
fprintf(pp, "%-16s", "InnerProduct");
} else {
fprintf(pp, "%-16s", "Gemm");
}
} else if (op == "GlobalAveragePool") {
fprintf(pp, "%-16s", "Pooling");
} else if (op == "GlobalMaxPool") {
fprintf(pp, "%-16s", "Pooling");
} else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" ||
op == "adaptive_max_pool2d") {
fprintf(pp, "%-16s", "Pooling");
} else if (op == "GroupNorm") {
fprintf(pp, "%-16s", "GroupNorm");
} else if (op == "GRU") {
fprintf(pp, "%-16s", "GRU");
} else if (op == "HardSigmoid") {
fprintf(pp, "%-16s", "HardSigmoid");
} else if (op == "HardSwish") {
fprintf(pp, "%-16s", "HardSwish");
} else if (op == "ImageScaler") {
fprintf(pp, "%-16s", "Scale");
} else if (op == "InstanceNormalization") {
fprintf(pp, "%-16s", "InstanceNorm");
} else if (op == "LayerNorm") {
fprintf(pp, "%-16s", "LayerNorm");
} else if (op == "LeakyRelu") {
fprintf(pp, "%-16s", "ReLU");
} else if (op == "Threshold") {
fprintf(pp, "%-16s", "Threshold");
} else if (op == "Log") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "LRN") {
fprintf(pp, "%-16s", "LRN");
} else if (op == "LSTM") {
fprintf(pp, "%-16s", "LSTM");
} else if (op == "MatMul") {
if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) {
fprintf(pp, "%-16s", "InnerProduct");
} else {
fprintf(pp, "%-16s", "Gemm");
}
} else if (op == "Max") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "Min") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "Mul") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "MultiHeadAttention") {
fprintf(pp, "%-16s", "MultiHeadAttention");
} else if (op == "Neg") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "NonMaxSuppression") {
fprintf(pp, "%-16s", "NonMaxSuppression");
} else if (op == "Normalize") {
fprintf(pp, "%-16s", "Normalize");
} else if (op == "Pad") {
fprintf(pp, "%-16s", "Padding");
} else if (op == "PixelShuffle") {
fprintf(pp, "%-16s", "PixelShuffle");
} else if (op == "Pow") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "PriorBox") {
fprintf(pp, "%-16s", "PriorBox");
} else if (op == "PRelu") {
fprintf(pp, "%-16s", "PReLU");
} else if (op == "Range") {
fprintf(pp, "%-16s", "Range");
} else if (op == "Reciprocal") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" ||
op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" ||
op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") {
fprintf(pp, "%-16s", "Reduction");
} else if (op == "Relu") {
fprintf(pp, "%-16s", "ReLU");
} else if (op == "Reorg") {
fprintf(pp, "%-16s", "Reorg");
} else if (op == "Reshape") {
fprintf(pp, "%-16s", "Reshape");
} else if (op == "RNN") {
fprintf(pp, "%-16s", "RNN");
} else if (op == "RDiv") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "RSub") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "RoiAlign") {
fprintf(pp, "%-16s", "ROIAlign");
} else if (op == "ScatterND") {
fprintf(pp, "%-16s", "ScatterND");
} else if (op == "Shape") {
fprintf(pp, "%-16s", "Shape");
} else if (op == "ShuffleChannel") {
fprintf(pp, "%-16s", "ShuffleChannel");
} else if (op == "Sigmoid") {
fprintf(pp, "%-16s", "Sigmoid");
} else if (op == "Sin") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "SkipLayerNormalization") {
fprintf(pp, "%-16s", "SkipLayerNormalization");
} else if (op == "Slice") {
std::vector<int> ends;
std::vector<int> steps;
bool use_crop = true;
if (node.input_size() == 1) {
ends = get_node_attr_ai(node, "ends");
steps = get_node_attr_ai(node, "steps"); // TODO
} else {
ends = get_node_attr_from_input_ai(weights[node.input(2)]);
if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]);
}
// assert step == 1
for (int i = 0; i < (int)steps.size(); i++) {
if (steps[i] != 1 && steps[i] < ends[i]) {
use_crop = false;
break;
}
}
if (use_crop) {
fprintf(pp, "%-16s", "Crop");
} else {
fprintf(pp, "%-16s", "TensorSlice");
}
} else if (op == "Softmax") {
fprintf(pp, "%-16s", "Softmax");
} else if (op == "Softplus") {
fprintf(pp, "%-16s", "Softplus");
} else if (op == "Split") {
fprintf(pp, "%-16s", "Slice");
} else if (op == "Sqrt") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Squeeze") {
std::vector<int> axes = get_node_attr_ai(node, "axes");
// fprintf(stderr, "axes[0]: %d\n",axes[0]);
if (axes[0] == 0) {
fprintf(pp, "%-16s", "Noop");
} else {
fprintf(pp, "%-16s", "Squeeze");
}
} else if (op == "Sub") {
fprintf(pp, "%-16s", "BinaryOp");
} else if (op == "Sum") {
fprintf(pp, "%-16s", "Eltwise");
} else if (op == "Swish") {
fprintf(pp, "%-16s", "Swish");
} else if (op == "Tan") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Tanh") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Tile") {
fprintf(pp, "%-16s", "TileOnnx");
} else if (op == "TopK") {
fprintf(pp, "%-16s", "TopK");
} else if (op == "Transpose") {
fprintf(pp, "%-16s", "Permute");
} else if (op == "Upsample" || op == "Resize") {
fprintf(pp, "%-16s", "Interp");
} else if (op == "Unsqueeze") {
std::vector<int> axes = get_node_attr_ai(node, "axes");
// fprintf(stderr, "axes[0]: %d\n",axes[0]);
if (axes[0] == 0) {
fprintf(pp, "%-16s", "Noop");
} else {
fprintf(pp, "%-16s", "ExpandDims");
}
} else if (op == "Where") {
fprintf(pp, "%-16s", "Where");
} else if (op == "Yolov3DetectionOutput") {
fprintf(pp, "%-16s", "Yolov3DetectionOutput");
} else {
// TODO
fprintf(stderr, "%s not supported yet!\n", op.c_str());
fprintf(pp, "%-16s", op.c_str());
}
fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
for (int j = 0; j < (int)node.input_size(); j++) {
std::string input_name = node.input(j);
// check weight
if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) {
continue;
}
if (input_name.empty()) {
continue;
}
if (split_node_reference.find(input_name) != split_node_reference.end()) {
int refidx = split_node_reference[input_name] - 1;
split_node_reference[input_name] = refidx;
char splitsuffix[256];
sprintf(splitsuffix, "_splitncnn_%d", refidx);
input_name = input_name + splitsuffix;
}
fprintf(pp, " %s", input_name.c_str());
}
for (int j = 0; j < output_size; j++) {
const std::string& output_name = node.output(j);
fprintf(pp, " %s", output_name.c_str());
}
if (op == "Abs") {
int op_type = 0;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Acos") {
int op_type = 13;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Add") {
int op_type = 0;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "ArgMax") {
int axis = get_node_attr_i(node, "axis");
int keepdims = get_node_attr_i(node, "keepdims");
fprintf(pp, " 0=%d", axis - 1);
fprintf(pp, " 3=%d", keepdims);
} else if (op == "Asin") {
int op_type = 12;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Atan") {
int op_type = 14;
fprintf(pp, " 0=%d", op_type);
} else if (op == "AveragePool" || op == "MaxPool") {
std::string auto_pad = get_node_attr_s(node, "auto_pad");
int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
std::vector<int> strides = get_node_attr_ai(node, "strides");
std::vector<int> pads = get_node_attr_ai(node, "pads");
int pool = op == "AveragePool" ? 1 : 0;
int pad_mode = 1;
if (auto_pad == "SAME_UPPER") {
pad_mode = 2;
} else if (auto_pad == "SAME_LOWER") {
pad_mode = 3;
}
if (ceil_mode == 1) {
pad_mode = 0;
}
fprintf(pp, " 0=%d", pool);
if (kernel_shape.size() == 1) {
fprintf(pp, " 1=%d", kernel_shape[0]);
} else if (kernel_shape.size() == 2) {
fprintf(pp, " 1=%d", kernel_shape[1]);
fprintf(pp, " 11=%d", kernel_shape[0]);
}
if (strides.size() == 1) {
fprintf(pp, " 2=%d", strides[0]);
} else if (strides.size() == 2) {
fprintf(pp, " 2=%d", strides[1]);
fprintf(pp, " 12=%d", strides[0]);
}
if (pads.size() == 1) {
fprintf(pp, " 3=%d", pads[0]);
} else if (pads.size() == 2) {
fprintf(pp, " 3=%d", pads[1]);
fprintf(pp, " 13=%d", pads[0]);
} else if (pads.size() == 4) {
fprintf(pp, " 3=%d", pads[1]);
fprintf(pp, " 13=%d", pads[0]);
fprintf(pp, " 14=%d", pads[3]);
fprintf(pp, " 15=%d", pads[2]);
}
fprintf(pp, " 5=%d", pad_mode);
if (op == "AveragePool") {
int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
fprintf(pp, " 6=%d", avgpool_count_include_pad);
}
} else if (op == "BatchNormalization") {
float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
const onnx::TensorProto& mean = weights[node.input(3)];
const onnx::TensorProto& var = weights[node.input(4)];
int channels = get_tensor_proto_data_size(scale);
fprintf(pp, " 0=%d", channels);
fwrite_tensor_proto_data(scale, bp);
fwrite_tensor_proto_data(mean, bp);
// apply epsilon to var
{
const float* v =
var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
for (int j = 0; j < channels; j++) {
float ve = v[j] + epsilon;
fwrite(&ve, sizeof(float), 1, bp);
}
}
fwrite_tensor_proto_data(B, bp);
} else if (op == "BiasGelu") {
const onnx::TensorProto& B = weights[node.input(1)];
fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(B, bp);
} else if (op == "Ceil") {
int op_type = 3;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Clip") {
float min;
float max;
if (node.input_size() == 1) {
min = get_node_attr_f(node, "min", -FLT_MAX);
max = get_node_attr_f(node, "max", FLT_MAX);
} else {
min = weights.find(node.input(1)) != weights.end()
? get_node_attr_from_input<float>(weights[node.input(1)])
: -FLT_MAX;
max = weights.find(node.input(2)) != weights.end()
? get_node_attr_from_input<float>(weights[node.input(2)])
: FLT_MAX;
}
fprintf(pp, " 0=%e", min);
fprintf(pp, " 1=%e", max);
} else if (op == "Concat") {
int axis = get_node_attr_i(node, "axis", 1);
fprintf(pp, " 0=%d", axis - 1);
} else if (op == "Constant") {
// never reach here
} else if (op == "ConstantOfShape") {
float value = 0.f;
value = get_node_attr_f(node, "value", 0.f);
fprintf(pp, " 0=%f", value);
} else if (op == "Conv") {
const onnx::TensorProto& W = weights[node.input(1)];
int num_filter = W.dims(0);
int has_bias = node.input_size() == 3 ? 1 : 0;
std::string auto_pad = get_node_attr_s(node, "auto_pad");
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
std::vector<int> dilations = get_node_attr_ai(node, "dilations");
std::vector<int> strides = get_node_attr_ai(node, "strides");
std::vector<int> pads = get_node_attr_ai(node, "pads");
int group = get_node_attr_i(node, "group", 1);
fprintf(pp, " 0=%d", num_filter);
if (kernel_shape.size() == 1) {
fprintf(pp, " 1=%d", kernel_shape[0]);
} else if (kernel_shape.size() == 2) {
fprintf(pp, " 1=%d", kernel_shape[1]);
fprintf(pp, " 11=%d", kernel_shape[0]);
}
if (dilations.size() == 1) {
fprintf(pp, " 2=%d", dilations[0]);
} else if (dilations.size() == 2) {
fprintf(pp, " 2=%d", dilations[1]);
fprintf(pp, " 12=%d", dilations[0]);
}
if (strides.size() == 1) {
fprintf(pp, " 3=%d", strides[0]);
} else if (strides.size() == 2) {
fprintf(pp, " 3=%d", strides[1]);
fprintf(pp, " 13=%d", strides[0]);
}
if (auto_pad == "SAME_UPPER") {
fprintf(pp, " 4=-233");
} else if (auto_pad == "SAME_LOWER") {
fprintf(pp, " 4=-234");
} else {
if (pads.size() == 1) {
fprintf(pp, " 4=%d", pads[0]);
} else if (pads.size() == 2) {
fprintf(pp, " 4=%d", pads[1]);
fprintf(pp, " 14=%d", pads[0]);
} else if (pads.size() == 4) {
fprintf(pp, " 4=%d", pads[1]);
fprintf(pp, " 14=%d", pads[0]);
fprintf(pp, " 15=%d", pads[3]);
fprintf(pp, " 16=%d", pads[2]);
}
}
fprintf(pp, " 5=%d", has_bias);
fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
if (group > 1) {
fprintf(pp, " 7=%d", group);
}
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(W, bp);
if (has_bias) {
const onnx::TensorProto& B = weights[node.input(2)];
fwrite_tensor_proto_data(B, bp);
}
} else if (op == "ConvTranspose") {
const onnx::TensorProto& W = weights[node.input(1)];
int has_bias = node.input_size() == 3 ? 1 : 0;
std::string auto_pad = get_node_attr_s(node, "auto_pad");
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
std::vector<int> dilations = get_node_attr_ai(node, "dilations");
std::vector<int> strides = get_node_attr_ai(node, "strides");
std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
std::vector<int> pads = get_node_attr_ai(node, "pads");
int group = get_node_attr_i(node, "group", 1);
int num_filter = W.dims(1) * group;
fprintf(pp, " 0=%d", num_filter);
if (kernel_shape.size() == 1) {
fprintf(pp, " 1=%d", kernel_shape[0]);
} else if (kernel_shape.size() == 2) {
fprintf(pp, " 1=%d", kernel_shape[1]);
fprintf(pp, " 11=%d", kernel_shape[0]);
}
if (dilations.size() == 1) {
fprintf(pp, " 2=%d", dilations[0]);
} else if (dilations.size() == 2) {
fprintf(pp, " 2=%d", dilations[1]);
fprintf(pp, " 12=%d", dilations[0]);
}
if (strides.size() == 1) {
fprintf(pp, " 3=%d", strides[0]);
} else if (strides.size() == 2) {
fprintf(pp, " 3=%d", strides[1]);
fprintf(pp, " 13=%d", strides[0]);
}
if (auto_pad == "SAME_UPPER") {
fprintf(pp, " 4=-233");
} else if (auto_pad == "SAME_LOWER") {
fprintf(pp, " 4=-234");
} else {
if (pads.size() == 1) {
fprintf(pp, " 4=%d", pads[0]);
} else if (pads.size() == 2) {
fprintf(pp, " 4=%d", pads[1]);
fprintf(pp, " 14=%d", pads[0]);
} else if (pads.size() == 4) {
fprintf(pp, " 4=%d", pads[1]);
fprintf(pp, " 14=%d", pads[0]);
fprintf(pp, " 15=%d", pads[3]);
fprintf(pp, " 16=%d", pads[2]);
}
}
if (output_padding.size() == 1) {
fprintf(pp, " 18=%d", output_padding[0]);
} else if (output_padding.size() == 2) {
fprintf(pp, " 18=%d", output_padding[1]);
fprintf(pp, " 19=%d", output_padding[0]);
}
if (output_shape.size() == 1) {
fprintf(pp, " 20=%d", output_shape[0]);
} else if (output_shape.size() == 2) {
fprintf(pp, " 20=%d", output_shape[1]);
fprintf(pp, " 21=%d", output_shape[0]);
}
fprintf(pp, " 5=%d", has_bias);
fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
if (group > 1) {
fprintf(pp, " 7=%d", group);
}
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
int maxk = 0;
if (kernel_shape.size() == 2) {
maxk = kernel_shape[1] * kernel_shape[0];
} else {
maxk = kernel_shape[0] * kernel_shape[0];
}
int weight_data_size = get_tensor_proto_data_size(W);
const float* weight_data = 0;
if (W.has_raw_data()) {
weight_data = (const float*)W.raw_data().data();
} else if (W.data_type() == 1) {
weight_data = W.float_data().data();
}
for (int g = 0; g < group; g++) {
// reorder weight from inch-outch to outch-inch
int num_filter_g = num_filter / group;
int num_input = weight_data_size / maxk / num_filter_g / group;
const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
for (int k = 0; k < num_filter_g; k++) {
for (int j = 0; j < num_input; j++) {
fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
}
}
}
if (has_bias) {
const onnx::TensorProto& B = weights[node.input(2)];
fwrite_tensor_proto_data(B, bp);
}
} else if (op == "Cos") {
int op_type = 10;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Crop") {
auto starts = get_node_attr_ai(node, "starts");
fprintf(pp, " -23309=%zu", starts.size());
for (size_t j = 0; j < starts.size(); ++j) {
fprintf(pp, ",%i", starts[j]);
}
auto ends = get_node_attr_ai(node, "ends");
fprintf(pp, " -23310=%zu", ends.size());
for (size_t j = 0; j < ends.size(); ++j) {
fprintf(pp, ",%i", ends[j]);
}
auto axis = get_node_attr_ai(node, "axis");
fprintf(pp, " -23311=%zu", axis.size());
for (size_t j = 0; j < axis.size(); ++j) {
fprintf(pp, ",%i", axis[j]);
}
} else if (op == "DepthToSpace") {
// pixelshuffle
int scale_factor = get_node_attr_i(node, "blocksize", 1);
std::string mode = get_node_attr_s(node, "mode");
fprintf(pp, " 0=%d", scale_factor);
if (mode == "CRD") {
fprintf(pp, " 1=0");
} else if (mode == "DCR") {
fprintf(pp, " 1=1");
}
} else if (op == "DetectionOutput") {
float score_threshold = get_node_attr_f(node, "score_threshold");
float nms_threshold = get_node_attr_f(node, "nms_threshold");
int nms_top_k = get_node_attr_i(node, "nms_top_k");
int keep_top_k = get_node_attr_i(node, "keep_top_k");
int num_class = get_node_attr_i(node, "num_class");
std::vector<float> vars = get_node_attr_af(node, "vars");
fprintf(pp, " 0=%d", num_class);
fprintf(pp, " 1=%f", nms_threshold);
fprintf(pp, " 2=%d", nms_top_k);
fprintf(pp, " 3=%d", keep_top_k);
fprintf(pp, " 4=%f", score_threshold);
fprintf(pp, " 5=%f", vars[0]);
fprintf(pp, " 6=%f", vars[1]);
fprintf(pp, " 7=%f", vars[2]);
fprintf(pp, " 8=%f", vars[3]);
} else if (op == "Div") {
int op_type = 3;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "Dropout") {
// no-op
} else if (op == "Elu") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
fprintf(pp, " 0=%e", alpha);
} else if (op == "EmbedLayerNormalization") {
const onnx::TensorProto& words = weights[node.input(2)];
const onnx::TensorProto& positions = weights[node.input(3)];
const onnx::TensorProto& W = weights[node.input(5)];
const onnx::TensorProto& B = weights[node.input(6)];
fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(words, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(positions, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(W, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(B, bp);
} else if (op == "Equal") {
int op_type = 0;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Exp") {
int op_type = 7;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Flatten") {
int axis = get_node_attr_i(node, "axis", 1);
if (axis != 1) {
fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
}
} else if (op == "Floor") {
int op_type = 2;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Gather") {
if (weights[node.input(1)].dims_size() > 1) {
fprintf(stderr, "Unsupported indice dims > 1");
}
int axis = get_node_attr_i(node, "axis", 1) - 1;
if (axis < 0) {
fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1);
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=1");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
int transA = get_node_attr_i(node, "transA", 0);
int transB = get_node_attr_i(node, "transB", 0);
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
// InnerProduct-like A * B + C
const onnx::TensorProto& B = weights[node.input(1)];
// B has transposed.
int num_output = B.dims(0);
fprintf(pp, " 0=%d", num_output);
if (node.input_size() == 3) {
fprintf(pp, " 1=1");
} else {
fprintf(pp, " 1=0");
}
fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(B, bp);
if (node.input_size() == 3) {
const onnx::TensorProto& C = weights[node.input(2)];
fwrite_tensor_proto_data(C, bp);
}
} else {
// gemm
fprintf(pp, " 0=%e", alpha);
fprintf(pp, " 1=%e", beta);
fprintf(pp, " 2=%d", transA);
fprintf(pp, " 3=%d", transB);
}
} else if (op == "GlobalAveragePool") {
int pool = 1;
int global_pool = 1;
fprintf(pp, " 0=%d", pool);
fprintf(pp, " 4=%d", global_pool);
} else if (op == "GlobalMaxPool") {
int pool = 0;
int global_pool = 1;
fprintf(pp, " 0=%d", pool);
fprintf(pp, " 4=%d", global_pool);
} else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" ||
op == "adaptive_max_pool2d") {
int pool = 0;
if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d") {
pool = 1;
}
int adaptive_pooling = 1;
const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);
fprintf(pp, " 0=%d", pool);
fprintf(pp, " 7=%d", adaptive_pooling);
if (out_shape.size() == 1) {
fprintf(pp, " 8=%d", out_shape[0]);
} else if (out_shape.size() == 2) {
// out_w
fprintf(pp, " 8=%d", out_shape[1]);
// out_h
fprintf(pp, " 18=%d", out_shape[0]);
}
} else if (op == "GroupNorm") {
int groups = get_node_attr_i(node, "groups", 1);
int channels = get_node_attr_i(node, "channels", 1);
float eps = get_node_attr_f(node, "epsilon", 1e-5f);
int affine = get_node_attr_i(node, "affine", 1);
if (affine) {
// discard affine-less S=1 B=0
std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 &&
affine_B[0] == 0.f) {
affine = 0;
} else {
affine = 0;
{
for (int j = 0; j < channels; j++) {
if (affine_S[j] != 1.f || affine_B[j] != 0.f) {
affine = 1;
break;
}
}
}
}
}
fprintf(pp, " 0=%d", groups);
fprintf(pp, " 1=%d", channels);
fprintf(pp, " 2=%e", eps);
fprintf(pp, " 3=%d", affine);
if (affine) {
const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
fwrite_tensor_proto_data(scale, bp);
fwrite_tensor_proto_data(B, bp);
}
} else if (op == "GRU") {
const onnx::TensorProto& W = weights[node.input(1)];
const onnx::TensorProto& R = weights[node.input(2)];
const onnx::TensorProto& B = weights[node.input(3)];
int hidden_size = get_node_attr_i(node, "hidden_size", 0);
std::string direction = get_node_attr_s(node, "direction");
int direction_type = 0;
if (direction == "forward") {
direction_type = 0;
} else if (direction == "reverse") {
direction_type = 1;
} else if (direction == "bidirectional") {
direction_type = 2;
}
int weight_data_size = get_tensor_proto_data_size(W);
fprintf(pp, " 0=%d", hidden_size);
fprintf(pp, " 1=%d", weight_data_size);
fprintf(pp, " 2=%d", direction_type);
int num_directions = direction_type == 2 ? 2 : 1;
int quantize_tag = 0;
// reorder num_directions-URN-hidden-size to
// num_directions-RUN-hidden-size
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
const float* wptr =
W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();
const float* uptr = wptr;
const float* rptr = wptr + weight_data_size_g;
const float* nptr = wptr + weight_data_size_g * 2;
fwrite(rptr, sizeof(float), weight_data_size_g, bp);
fwrite(uptr, sizeof(float), weight_data_size_g, bp);
fwrite(nptr, sizeof(float), weight_data_size_g, bp);
if (direction_type == 2) {
uptr += weight_data_size_g * 3;
rptr += weight_data_size_g * 3;
nptr += weight_data_size_g * 3;
fwrite(rptr, sizeof(float), weight_data_size_g, bp);
fwrite(uptr, sizeof(float), weight_data_size_g, bp);
fwrite(nptr, sizeof(float), weight_data_size_g, bp);
}
}
// reduce U and R bias except N
// reorder num_directions-URN-hidden to num_directions-RUN-hidden
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
const float* bptr =
B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
const float* wuptr = bptr;
const float* wrptr = bptr + bias_data_size_g;
const float* wnptr = bptr + bias_data_size_g * 2;
const float* buptr = bptr + bias_data_size_g * 3;
const float* brptr = bptr + bias_data_size_g * 4;
const float* bnptr = bptr + bias_data_size_g * 5;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = wrptr[j] + brptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = wuptr[j] + buptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
if (direction_type == 2) {
wuptr += bias_data_size_g * 6;
wrptr += bias_data_size_g * 6;
wnptr += bias_data_size_g * 6;
buptr += bias_data_size_g * 6;
brptr += bias_data_size_g * 6;
bnptr += bias_data_size_g * 6;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = wrptr[j] + brptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = wuptr[j] + buptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
}
}
// reorder num_directions-URN-hidden-hidden to
// num_directions-RUN-hidden-hidden
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
const float* Rptr =
R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();
const float* uptr = Rptr;
const float* rptr = Rptr + weight_data_size_g;
const float* nptr = Rptr + weight_data_size_g * 2;
fwrite(rptr, sizeof(float), weight_data_size_g, bp);
fwrite(uptr, sizeof(float), weight_data_size_g, bp);
fwrite(nptr, sizeof(float), weight_data_size_g, bp);
if (direction_type == 2) {
uptr += weight_data_size_g * 3;
rptr += weight_data_size_g * 3;
nptr += weight_data_size_g * 3;
fwrite(rptr, sizeof(float), weight_data_size_g, bp);
fwrite(uptr, sizeof(float), weight_data_size_g, bp);
fwrite(nptr, sizeof(float), weight_data_size_g, bp);
}
}
} else if (op == "HardSigmoid") {
float alpha = get_node_attr_f(node, "alpha", 0.2f);
float beta = get_node_attr_f(node, "beta", 0.5f);
fprintf(pp, " 0=%e", alpha);
fprintf(pp, " 1=%e", beta);
} else if (op == "HardSwish") {
float alpha = get_node_attr_f(node, "alpha", 0.2f);
float beta = get_node_attr_f(node, "beta", 0.5f);
fprintf(pp, " 0=%e", alpha);
fprintf(pp, " 1=%e", beta);
} else if (op == "ImageScaler") {
std::vector<float> bias = get_node_attr_af(node, "bias");
float scale = get_node_attr_f(node, "scale", 1.f);
int channels = (int)bias.size();
fprintf(pp, " 0=%d", channels);
fprintf(pp, " 1=1");
for (int j = 0; j < channels; j++) {
fwrite(&scale, sizeof(float), 1, bp);
}
fwrite(&bias[0], sizeof(float), channels, bp);
} else if (op == "InstanceNormalization") {
float eps = get_node_attr_f(node, "epsilon", 1e-5f);
// discard affine-less S=1 B=0
std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
int channels = (int)affine_S.size();
int affine = 0;
{
for (int j = 0; j < channels; j++) {
if (affine_S[j] != 1.f || affine_B[j] != 0.f) {
affine = 1;
break;
}
}
}
fprintf(pp, " 0=%d", channels);
fprintf(pp, " 1=%e", eps);
fprintf(pp, " 2=%d", affine);
if (affine) {
const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
fwrite_tensor_proto_data(scale, bp);
fwrite_tensor_proto_data(B, bp);
}
} else if (op == "LayerNorm") {
float eps = get_node_attr_f(node, "epsilon", 1e-5f);
int affine = get_node_attr_i(node, "affine", 1);
if (affine) {
// discard affine-less S=1 B=0
std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
int affine_size = (int)affine_S.size();
affine = 0;
{
for (int j = 0; j < affine_size; j++) {
if (affine_S[j] != 1.f || affine_B[j] != 0.f) {
affine = 1;
break;
}
}
}
if (affine) {
fprintf(pp, " 0=%d", affine_size);
}
}
fprintf(pp, " 1=%e", eps);
fprintf(pp, " 2=%d", affine);
if (affine) {
const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
fwrite_tensor_proto_data(scale, bp);
fwrite_tensor_proto_data(B, bp);
}
} else if (op == "LeakyRelu") {
float alpha = get_node_attr_f(node, "alpha", 0.01f);
fprintf(pp, " 0=%e", alpha);
} else if (op == "Threshold") {
float threshold = get_node_attr_f(node, "threshold", 0.f);
fprintf(pp, " 0=%e", threshold);
} else if (op == "Log") {
int op_type = 8;
fprintf(pp, " 0=%d", op_type);
} else if (op == "LRN") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 0.5f);
float bias = get_node_attr_f(node, "bias", 1.f);
int size = get_node_attr_i(node, "size", 1);
int norm_region = 0;
fprintf(pp, " 0=%d", norm_region);
fprintf(pp, " 1=%d", size);
fprintf(pp, " 2=%e", alpha);
fprintf(pp, " 3=%e", beta);
fprintf(pp, " 4=%e", bias);
} else if (op == "LSTM") {
const onnx::TensorProto& W = weights[node.input(1)];
const onnx::TensorProto& R = weights[node.input(2)];
const onnx::TensorProto& B = weights[node.input(3)];
int hidden_size = get_node_attr_i(node, "hidden_size", 0);
std::string direction = get_node_attr_s(node, "direction");
int direction_type = 0;
if (direction == "forward") {
direction_type = 0;
} else if (direction == "reverse") {
direction_type = 1;
} else if (direction == "bidirectional") {
direction_type = 2;
}
int weight_data_size = get_tensor_proto_data_size(W);
fprintf(pp, " 0=%d", hidden_size);
fprintf(pp, " 1=%d", weight_data_size);
fprintf(pp, " 2=%d", direction_type);
int num_directions = direction_type == 2 ? 2 : 1;
int quantize_tag = 0;
// reorder num_directions-IOFG-hidden-size to
// num_directions-IFOG-hidden-size
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
const float* wptr =
W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();
const float* iptr = wptr;
const float* optr = wptr + weight_data_size_g;
const float* fptr = wptr + weight_data_size_g * 2;
const float* gptr = wptr + weight_data_size_g * 3;
fwrite(iptr, sizeof(float), weight_data_size_g, bp);
fwrite(fptr, sizeof(float), weight_data_size_g, bp);
fwrite(optr, sizeof(float), weight_data_size_g, bp);
fwrite(gptr, sizeof(float), weight_data_size_g, bp);
if (direction_type == 2) {
iptr += weight_data_size_g * 4;
optr += weight_data_size_g * 4;
fptr += weight_data_size_g * 4;
gptr += weight_data_size_g * 4;
fwrite(iptr, sizeof(float), weight_data_size_g, bp);
fwrite(fptr, sizeof(float), weight_data_size_g, bp);
fwrite(optr, sizeof(float), weight_data_size_g, bp);
fwrite(gptr, sizeof(float), weight_data_size_g, bp);
}
}
// reduce xc and hc bias
// reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
const float* xcbptr =
B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
const float* xiptr = xcbptr;
const float* xoptr = xcbptr + bias_data_size_g;
const float* xfptr = xcbptr + bias_data_size_g * 2;
const float* xgptr = xcbptr + bias_data_size_g * 3;
const float* hiptr = xcbptr + bias_data_size_g * 4;
const float* hoptr = xcbptr + bias_data_size_g * 5;
const float* hfptr = xcbptr + bias_data_size_g * 6;
const float* hgptr = xcbptr + bias_data_size_g * 7;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xiptr[j] + hiptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xfptr[j] + hfptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xoptr[j] + hoptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xgptr[j] + hgptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
if (direction_type == 2) {
xiptr += bias_data_size_g * 8;
xoptr += bias_data_size_g * 8;
xfptr += bias_data_size_g * 8;
xgptr += bias_data_size_g * 8;
hiptr += bias_data_size_g * 8;
hoptr += bias_data_size_g * 8;
hfptr += bias_data_size_g * 8;
hgptr += bias_data_size_g * 8;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xiptr[j] + hiptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xfptr[j] + hfptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xoptr[j] + hoptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xgptr[j] + hgptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
// reorder num_directions-IOFG-hidden-hidden to
// num_directions-IFOG-hidden-hidden
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
const float* rptr =
R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();
const float* iptr = rptr;
const float* optr = rptr + weight_data_size_g;
const float* fptr = rptr + weight_data_size_g * 2;
const float* gptr = rptr + weight_data_size_g * 3;
fwrite(iptr, sizeof(float), weight_data_size_g, bp);
fwrite(fptr, sizeof(float), weight_data_size_g, bp);
fwrite(optr, sizeof(float), weight_data_size_g, bp);
fwrite(gptr, sizeof(float), weight_data_size_g, bp);
if (direction_type == 2) {
iptr += weight_data_size_g * 4;
optr += weight_data_size_g * 4;
fptr += weight_data_size_g * 4;
gptr += weight_data_size_g * 4;
fwrite(iptr, sizeof(float), weight_data_size_g, bp);
fwrite(fptr, sizeof(float), weight_data_size_g, bp);
fwrite(optr, sizeof(float), weight_data_size_g, bp);
fwrite(gptr, sizeof(float), weight_data_size_g, bp);
}
}
} else if (op == "MatMul") {
if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) {
// InnerProduct
const onnx::TensorProto& B = weights[node.input(1)];
int weight_data_size = get_tensor_proto_data_size(B);
int num_output = B.dims(B.dims_size() - 1);
int num_input = weight_data_size / num_output;
fprintf(pp, " 0=%d", num_output);
fprintf(pp, " 1=0");
fprintf(pp, " 2=%d", weight_data_size);
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
// reorder num_input-num_output to num_output-num_input
{
const float* bptr =
B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
for (int j = 0; j < num_output; j++) {
for (int k = 0; k < num_input; k++) {
float vb = bptr[k * num_output + j];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
// fwrite_tensor_proto_data(B, bp)
} else {
// default matrix multiplication
}
} else if (op == "Max") {
int op_type = 4;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "Min") {
int op_type = 5;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "Mul") {
int op_type = 2;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "MultiHeadAttention") {
int embed_dim = get_node_attr_i(node, "embed_dim", 0);
int num_heads = get_node_attr_i(node, "num_heads", 0);
fprintf(pp, " 0=%d", embed_dim);
fprintf(pp, " 1=%d", num_heads);
if (node.input_size() == 5) {
const onnx::TensorProto& qkvw = weights[node.input(1)];
const onnx::TensorProto& qkvb = weights[node.input(2)];
const onnx::TensorProto& ow = weights[node.input(3)];
const onnx::TensorProto& ob = weights[node.input(4)];
int weight_data_size = get_tensor_proto_data_size(ow);
fprintf(pp, " 2=%d", weight_data_size);
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose qw
{
const float* wptr =
qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
const float* bptr =
qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim * 3 + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
fwrite(bptr, sizeof(float), embed_dim, bp);
}
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose kw
{
const float* wptr =
qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
const float* bptr =
qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
bptr += embed_dim;
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim * 3 + k + embed_dim];
fwrite(&vb, sizeof(float), 1, bp);
}
}
fwrite(bptr, sizeof(float), embed_dim, bp);
}
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose vw
{
const float* wptr =
qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
const float* bptr =
qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
bptr += embed_dim * 2;
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2];
fwrite(&vb, sizeof(float), 1, bp);
}
}
fwrite(bptr, sizeof(float), embed_dim, bp);
}
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose ow
{
const float* wptr =
ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite_tensor_proto_data(ob, bp);
} else {
const onnx::TensorProto& qw = weights[node.input(3)];
const onnx::TensorProto& qb = weights[node.input(4)];
const onnx::TensorProto& kw = weights[node.input(5)];
const onnx::TensorProto& kb = weights[node.input(6)];
const onnx::TensorProto& vw = weights[node.input(7)];
const onnx::TensorProto& vb = weights[node.input(8)];
const onnx::TensorProto& ow = weights[node.input(9)];
const onnx::TensorProto& ob = weights[node.input(10)];
int weight_data_size = get_tensor_proto_data_size(qw);
fprintf(pp, " 2=%d", weight_data_size);
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose qw
{
const float* wptr =
qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite_tensor_proto_data(qb, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose kw
{
const float* wptr =
kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite_tensor_proto_data(kb, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose vw
{
const float* wptr =
vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite_tensor_proto_data(vb, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
// transpose ow
{
const float* wptr =
ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite_tensor_proto_data(ob, bp);
}
} else if (op == "Neg") {
int op_type = 1;
fprintf(pp, " 0=%d", op_type);
} else if (op == "NonMaxSuppression") {
int max_dets = 0;
float iou_thre = 0.f;
float score_thre = 0.f;
// fprintf(stderr, "%s\n", node.name().c_str());
// fprintf(stderr, "node.input_size(): %d\n", node.input_size());
if (node.input_size() >= 3) {
// fprintf(stderr, "ok12!\n");
max_dets = (int)(get_node_attr_from_input<float>(weights[node.input(2)]) + 0.5);
}
if (node.input_size() >= 4) {
// fprintf(stderr, "iou_thre: %f\n",
// get_node_attr_from_input<float>(weights[node.input(3)]));
iou_thre = get_node_attr_from_input<float>(weights[node.input(3)]);
}
if (node.input_size() >= 5) {
// fprintf(stderr, "score_thre: %f\n",
// get_node_attr_from_input<float>(weights[node.input(4)]));
score_thre = get_node_attr_from_input<float>(weights[node.input(4)]);
}
fprintf(pp, " 0=%d", max_dets);
fprintf(pp, " 1=%f", iou_thre);
fprintf(pp, " 2=%f", score_thre);
} else if (op == "Normalize") {
float eps = get_node_attr_f(node, "eps", 0.f);
int scale_data_size = 1;
fprintf(pp, " 1=1"); // channel_shared
fprintf(pp, " 2=%e", eps);
fprintf(pp, " 3=%d", scale_data_size);
fprintf(pp, " 9=1"); // TODO hardcode pytorch style
const float scale_data[1] = {1.f};
fwrite(scale_data, sizeof(float), 1, bp);
} else if (op == "Pad") {
std::string mode = get_node_attr_s(node, "mode");
float value = get_node_attr_f(node, "value", 0.f);
std::vector<int> pads;
if (node.input_size() == 1) {
pads = get_node_attr_ai(node, "pads");
} else {
pads = get_node_attr_from_input_ai(weights[node.input(1)]);
}
int type = 0;
if (mode == "constant") {
type = 0;
} else if (mode == "edge") {
type = 1;
} else if (mode == "reflect") {
type = 2;
}
int pad_size = (int)pads.size();
int top = 0;
int bottom = 0;
int left = 0;
int right = 0;
int front = 0;
int behind = 0;
if (pad_size == 8) {
// NCHW
top = pads[2];
bottom = pads[6];
left = pads[3];
right = pads[7];
front = pads[1];
behind = pads[5];
} else if (pad_size == 6) {
// NHW
top = pads[1];
bottom = pads[4];
left = pads[2];
right = pads[5];
} else {
// NW
left = pads[1];
right = pads[3];
}
fprintf(pp, " 0=%d", top);
fprintf(pp, " 1=%d", bottom);
fprintf(pp, " 2=%d", left);
fprintf(pp, " 3=%d", right);
fprintf(pp, " 4=%d", type);
fprintf(pp, " 5=%e", value);
fprintf(pp, " 7=%d", front);
fprintf(pp, " 8=%d", behind);
} else if (op == "Pow") {
int op_type = 6;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "PriorBox") {
std::vector<float> min_sizes = get_node_attr_af(node, "min_sizes");
std::vector<float> max_sizes = get_node_attr_af(node, "max_sizes");
std::vector<float> aspect_ratios = get_node_attr_af(node, "aspect_ratios");
fprintf(pp, " -23300=%zu", min_sizes.size());
for (size_t j = 0; j < min_sizes.size(); ++j) {
fprintf(pp, ",%f", min_sizes[j]);
}
fprintf(pp, " -23301=%zu", max_sizes.size());
for (size_t j = 0; j < max_sizes.size(); ++j) {
fprintf(pp, ",%f", max_sizes[j]);
}
fprintf(pp, " -23302=%zu", aspect_ratios.size());
for (size_t j = 0; j < aspect_ratios.size(); ++j) {
fprintf(pp, ",%f", aspect_ratios[j]);
}
int image_width = get_node_attr_i(node, "image_width");
int image_height = get_node_attr_i(node, "image_height");
float step_width = get_node_attr_f(node, "step_width");
float step_height = get_node_attr_f(node, "step_height");
float offset = get_node_attr_f(node, "offset");
int step_mmdetection = get_node_attr_i(node, "step_mmdetection");
fprintf(pp, " 9=%d", image_width);
fprintf(pp, " 10=%d", image_height);
fprintf(pp, " 11=%f", step_width);
fprintf(pp, " 12=%f", step_height);
fprintf(pp, " 13=%f", offset);
fprintf(pp, " 14=%d", step_mmdetection);
} else if (op == "PixelShuffle") {
int scale_factor = get_node_attr_i(node, "scale_factor", 1);
fprintf(pp, " 0=%d", scale_factor);
} else if (op == "PRelu") {
const onnx::TensorProto& slope = weights[node.input(1)];
int num_slope = get_tensor_proto_data_size(slope);
fprintf(pp, " 0=%d", num_slope);
fwrite_tensor_proto_data(slope, bp);
} else if (op == "Reciprocal") {
int op_type = 15;
fprintf(pp, " 0=%d", op_type);
} else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" ||
op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" ||
op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") {
int op_type = -233;
if (op == "ReduceSum")
op_type = 0;
else if (op == "ReduceSumSquare")
op_type = 2;
else if (op == "ReduceMean")
op_type = 3;
else if (op == "ReduceMax")
op_type = 4;
else if (op == "ReduceMin")
op_type = 5;
else if (op == "ReduceProd")
op_type = 6;
else if (op == "ReduceL1")
op_type = 7;
else if (op == "ReduceL2")
op_type = 8;
else if (op == "ReduceLogSum")
op_type = 9;
else if (op == "ReduceLogSumExp")
op_type = 10;
fprintf(pp, " 0=%d", op_type);
std::vector<int> axes = get_node_attr_ai(node, "axes");
int keepdims = get_node_attr_i(node, "keepdims", 1);
if (axes.size() > 0) {
// if axes set, reduce according to axes
fprintf(pp, " 1=%d", 0);
fprintf(pp, " -23303=%zu", axes.size());
for (size_t j = 0; j < axes.size(); j++) {
if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3)
fprintf(stderr, "Unsupported reduction axes !\n");
fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]);
}
} else {
// if axes not set, reduce all axes by default
fprintf(pp, " 1=%d", 1);
}
fprintf(pp, " 4=%d", keepdims);
fprintf(pp, " 5=1");
} else if (op == "Reorg") {
int stride = get_node_attr_i(node, "stride", 1);
fprintf(pp, " 0=%d", stride);
} else if (op == "Reshape") {
std::vector<int> shape;
if (node.input_size() == 1) {
shape = get_node_attr_ai(node, "shape");
} else if (weights.find(node.input(1)) != weights.end()) {
shape = get_node_attr_from_input_ai(weights[node.input(1)]);
} else {
fprintf(stderr, "Unsupported reshape weight ! \n");
}
if (shape.size() == 1) {
fprintf(pp, " 0=%d", shape[0]); // should never reach here
} else if (shape.size() == 2) {
fprintf(pp, " 0=%d", shape[1]);
} else if (shape.size() == 3) {
fprintf(pp, " 0=%d", shape[2]);
fprintf(pp, " 1=%d", shape[1]);
} else if (shape.size() == 4) {
fprintf(pp, " 0=%d", shape[3]);
fprintf(pp, " 1=%d", shape[2]);
fprintf(pp, " 2=%d", shape[1]);
} else if (shape.size() == 5) {
fprintf(pp, " 0=%d", shape[4] * shape[3]);
fprintf(pp, " 1=%d", shape[2]);
fprintf(pp, " 2=%d", shape[1]);
}
} else if (op == "Resize") {
std::string mode = get_node_attr_s(node, "mode");
std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
std::vector<float> scales;
std::vector<int> sizes;
if (node.input_size() == 2) {
// opset 10
scales = get_node_attr_from_input_af(weights[node.input(1)]);
} else {
// opset 11+
scales = get_node_attr_from_input_af(weights[node.input(2)]);
if (node.input_size() >= 4) {
sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
}
}
int resize_type = 1;
if (mode == "nearest") {
resize_type = 1;
} else if (mode == "linear") {
resize_type = 2;
} else if (mode == "cubic") {
resize_type = 3;
}
if (scales.empty() && sizes.empty()) {
fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
}
float h_scale = 1.f;
float w_scale = 1.f;
if (scales.size() == 2) {
w_scale = scales[1];
} else if (scales.size() == 3) {
h_scale = scales[1];
w_scale = scales[2];
} else if (scales.size() == 4) {
h_scale = scales[2];
w_scale = scales[3];
if (scales[1] != 1.f) fprintf(stderr, "Unsupported Resize scales !\n");
}
int output_height = 0;
int output_width = 0;
if (sizes.size() == 2) {
output_width = sizes[1];
} else if (sizes.size() == 3) {
output_height = sizes[1];
output_width = sizes[2];
} else if (sizes.size() == 4) {
output_height = sizes[2];
output_width = sizes[3];
}
int align_corner = 0;
if (align == "align_corners") {
align_corner = 1;
}
fprintf(pp, " 0=%d", resize_type);
fprintf(pp, " 1=%e", h_scale);
fprintf(pp, " 2=%e", w_scale);
fprintf(pp, " 3=%d", output_height);
fprintf(pp, " 4=%d", output_width);
fprintf(pp, " 6=%d", align_corner);
} else if (op == "RNN") {
const onnx::TensorProto& W = weights[node.input(1)];
const onnx::TensorProto& R = weights[node.input(2)];
const onnx::TensorProto& B = weights[node.input(3)];
int hidden_size = get_node_attr_i(node, "hidden_size", 0);
std::string direction = get_node_attr_s(node, "direction");
int direction_type = 0;
if (direction == "forward") {
direction_type = 0;
} else if (direction == "reverse") {
direction_type = 1;
} else if (direction == "bidirectional") {
direction_type = 2;
}
int weight_data_size = get_tensor_proto_data_size(W);
fprintf(pp, " 0=%d", hidden_size);
fprintf(pp, " 1=%d", weight_data_size);
fprintf(pp, " 2=%d", direction_type);
int num_directions = direction_type == 2 ? 2 : 1;
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(W, bp);
// reduce xc and hc bias
{
fwrite(&quantize_tag, sizeof(int), 1, bp);
int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
const float* bptr =
B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
const float* xiptr = bptr;
const float* hiptr = bptr + bias_data_size_g;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xiptr[j] + hiptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
if (direction_type == 2) {
xiptr += bias_data_size_g * 2;
hiptr += bias_data_size_g * 2;
for (int j = 0; j < bias_data_size_g; j++) {
float vb = xiptr[j] + hiptr[j];
fwrite(&vb, sizeof(float), 1, bp);
}
}
}
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(R, bp);
} else if (op == "RDiv") {
int op_type = 8;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "RSub") {
int op_type = 7;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "RoiAlign") {
int pooled_width = get_node_attr_i(node, "output_width", 1);
int pooled_height = get_node_attr_i(node, "output_height", 1);
float spatial_scale = get_node_attr_f(node, "spatial_scale", 1.f);
int sampling_ratio = get_node_attr_i(node, "sampling_ratio", 0);
fprintf(pp, " 0=%d", pooled_width);
fprintf(pp, " 1=%d", pooled_height);
fprintf(pp, " 2=%f", spatial_scale);
fprintf(pp, " 3=%d", sampling_ratio);
} else if (op == "ShuffleChannel") {
int group = get_node_attr_i(node, "group", 1);
int reverse = get_node_attr_i(node, "reverse", 0);
fprintf(pp, " 0=%d", group);
fprintf(pp, " 1=%d", reverse);
} else if (op == "Sigmoid") {
// no param
} else if (op == "Sin") {
int op_type = 9;
fprintf(pp, " 0=%d", op_type);
} else if (op == "SkipLayerNormalization") {
const onnx::TensorProto& W = weights[node.input(2)];
const onnx::TensorProto& B = weights[node.input(3)];
const onnx::TensorProto& B2 = weights[node.input(4)];
fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
int quantize_tag = 0;
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(W, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(B, bp);
fwrite(&quantize_tag, sizeof(int), 1, bp);
fwrite_tensor_proto_data(B2, bp);
} else if (op == "Slice") {
bool use_crop = true;
std::vector<int> starts;
std::vector<int> ends;
std::vector<int> axes;
std::vector<int> steps;
if (node.input_size() == 1) {
starts = get_node_attr_ai(node, "starts");
ends = get_node_attr_ai(node, "ends");
axes = get_node_attr_ai(node, "axes");
steps = get_node_attr_ai(node, "steps"); // TODO
} else {
starts = get_node_attr_from_input_ai(weights[node.input(1)]);
ends = get_node_attr_from_input_ai(weights[node.input(2)]);
if (node.input_size() >= 4) axes = get_node_attr_from_input_ai(weights[node.input(3)]);
if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]);
}
// assert step == 1 or step >= ends
for (int i = 0; i < (int)steps.size(); i++) {
if (steps[i] != 1 && steps[i] < ends[i]) {
use_crop = false;
fprintf(stderr, "Unsupported slice step ! Use custom TensorSlice\n");
}
}
if (use_crop) {
// filter out N-dim axis
if (!axes.empty()) {
for (int i = 0; i < (int)axes.size(); i++) {
int axis = axes[i];
if (axis == 0) {
starts.erase(starts.begin() + i);
ends.erase(ends.begin() + i);
axes.erase(axes.begin() + i);
break;
}
}
}
fprintf(pp, " -23309=%d", (int)starts.size());
for (int i = 0; i < (int)starts.size(); i++) {
fprintf(pp, ",%d", starts[i]);
}
fprintf(pp, " -23310=%d", (int)ends.size());
for (int i = 0; i < (int)ends.size(); i++) {
fprintf(pp, ",%d", ends[i]);
}
if (!axes.empty()) {
fprintf(pp, " -23311=%d", (int)axes.size());
for (int i = 0; i < (int)axes.size(); i++) {
int axis = axes[i];
if (axis == 0 || axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n");
if (axis > 0) axis = axis - 1; // -1 for skip N-dim
fprintf(pp, ",%d", axis);
}
}
} else {
fprintf(pp, " -23300=%d", (int)starts.size());
for (int i = 0; i < (int)starts.size(); i++) {
fprintf(pp, ",%d", starts[i]);
}
fprintf(pp, " -23301=%d", (int)ends.size());
for (int i = 0; i < (int)ends.size(); i++) {
fprintf(pp, ",%d", ends[i]);
}
if (!axes.empty()) {
fprintf(pp, " -23302=%d", (int)axes.size());
for (int i = 0; i < (int)axes.size(); i++) {
int axis = axes[i];
if (axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n");
fprintf(pp, ",%d", axis);
}
}
if (!steps.empty()) {
fprintf(pp, " -23303=%d", (int)steps.size());
for (int i = 0; i < (int)steps.size(); i++) {
int step = steps[i];
if (step == 0) fprintf(stderr, "Unsupported slice step ! Unsupported slice step\n");
fprintf(pp, ",%d", step);
}
}
}
} else if (op == "Softmax") {
int axis = get_node_attr_i(node, "axis", 1);
fprintf(pp, " 0=%d", axis - 1);
fprintf(pp, " 1=1");
} else if (op == "Split") {
int axis = get_node_attr_i(node, "axis", 0);
std::vector<int> split = get_node_attr_ai(node, "split");
if (axis < 1) fprintf(stderr, "Unsupported split axis !\n");
fprintf(pp, " -23300=%d", output_size);
if (split.empty()) {
for (int i = 0; i < output_size; i++) {
fprintf(pp, ",-233");
}
} else {
for (size_t i = 0; i < split.size() - 1; i++) {
fprintf(pp, ",%d", split[i]);
}
fprintf(pp, ",-233");
}
fprintf(pp, " 1=%d", axis - 1);
} else if (op == "Sqrt") {
int op_type = 5;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Squeeze") {
std::vector<int> axes = get_node_attr_ai(node, "axes");
if (axes.empty()) {
fprintf(pp, " 0=1");
fprintf(pp, " 1=1");
fprintf(pp, " 2=1");
} else {
bool flag = true;
for (int i = 0; i < (int)axes.size(); i++) {
if (axes[i] == 0) {
flag = false;
break;
}
}
if (flag == true) {
fprintf(pp, " -23303=%zu", axes.size());
for (int i = 0; i < (int)axes.size(); i++) {
if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
fprintf(stderr, "Unsupported squeeze axes !: %d, %s\n", axes[i], node.name().c_str());
fprintf(pp, ",%d", axes[i] - 1);
}
}
}
} else if (op == "Sub") {
int op_type = 1;
fprintf(pp, " 0=%d", op_type);
int with_scalar = get_node_attr_i(node, "with_scalar", 0);
float b = get_node_attr_f(node, "b", 0.f);
if (with_scalar) {
fprintf(pp, " 1=%d", with_scalar);
fprintf(pp, " 2=%e", b);
}
} else if (op == "Sum") {
int op_type = 1;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Swish") {
// no param
} else if (op == "Tan") {
int op_type = 11;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Tanh") {
int op_type = 16;
fprintf(pp, " 0=%d", op_type);
} else if (op == "TopK") {
int axis = get_node_attr_i(node, "axis", -1);
axis = axis > 0 ? axis - 1 : axis;
int largest = get_node_attr_i(node, "largest", 1);
int sorted = get_node_attr_i(node, "sorted", 1);
fprintf(pp, " 0=%d", axis);
fprintf(pp, " 1=%d", largest);
fprintf(pp, " 2=%d", sorted);
} else if (op == "Transpose") {
std::vector<int> perm = get_node_attr_ai(node, "perm");
if (perm.size() == 3) {
if (perm[1] == 1 && perm[2] == 2)
fprintf(pp, " 0=0"); // w h
else if (perm[1] == 2 && perm[2] == 1)
fprintf(pp, " 0=1"); // h w
else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
fprintf(pp, " 0=0"); // w h
else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
fprintf(pp, " 0=1"); // h w
} else if (perm.size() == 4) {
if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
fprintf(pp, " 0=0"); // w h c
else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
fprintf(pp, " 0=1"); // h w c
else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
fprintf(pp, " 0=2"); // w c h
else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
fprintf(pp, " 0=3"); // c w h
else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
fprintf(pp, " 0=4"); // h c w
else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
fprintf(pp, " 0=5"); // c h w
} else if (perm.size() == 5) {
if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
fprintf(pp, " 0=0"); // wx h c
else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
fprintf(pp, " 0=1"); // h wx c
else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
fprintf(pp, " 0=2"); // wx c h
else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
fprintf(pp, " 0=3"); // c wx h
else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
fprintf(pp, " 0=4"); // h c wx
else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
fprintf(pp, " 0=5"); // c h wx
else
fprintf(stderr, "Unsupported transpose type !\n");
}
} else if (op == "Upsample") {
std::string mode = get_node_attr_s(node, "mode");
std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
std::vector<float> scales;
if (node.input_size() == 1) {
scales = get_node_attr_af(node, "scales");
} else {
scales = get_node_attr_from_input_af(weights[node.input(1)]);
}
int resize_type = 1;
if (mode == "nearest") {
resize_type = 1;
} else if (mode == "bilinear" || mode == "linear") {
resize_type = 2;
} else if (mode == "trilinear") {
fprintf(stderr, "Unsupported Upsample mode !\n");
}
float h_scale = 1.f;
float w_scale = 1.f;
if (scales.size() == 2) {
w_scale = scales[1];
} else if (scales.size() == 3) {
h_scale = scales[1];
w_scale = scales[2];
} else if (scales.size() == 4) {
h_scale = scales[2];
w_scale = scales[3];
if (scales[1] != 1.f) fprintf(stderr, "Unsupported Upsample scales !\n");
} else {
fprintf(stderr, "Unsupported Upsample scales !\n");
}
int align_corner = 0;
if (align == "align_corners") {
align_corner = 1;
}
fprintf(pp, " 0=%d", resize_type);
fprintf(pp, " 1=%e", h_scale);
fprintf(pp, " 2=%e", w_scale);
fprintf(pp, " 6=%d", align_corner);
} else if (op == "Unsqueeze") {
std::vector<int> axes = get_node_attr_ai(node, "axes");
bool flag = true;
for (int i = 0; i < (int)axes.size(); i++) {
if (axes[i] == 0) {
flag = false;
break;
}
}
if (flag) {
fprintf(pp, " -23303=%zu", axes.size());
for (int i = 0; i < (int)axes.size(); i++) {
if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
fprintf(stderr, "Unsupported unsqueeze axes !: %d, %s\n", axes[i], node.name().c_str());
fprintf(pp, ",%d", axes[i] - 1);
}
}
} else if (op == "Yolov3DetectionOutput") {
int num_class = get_node_attr_i(node, "num_class");
int num_box = get_node_attr_i(node, "num_box");
float confidence_threshold = get_node_attr_f(node, "confidence_threshold");
float nms_threshold = get_node_attr_f(node, "nms_threshold");
fprintf(pp, " 0=%d", num_class);
fprintf(pp, " 1=%d", num_box);
fprintf(pp, " 2=%e", confidence_threshold);
fprintf(pp, " 3=%e", nms_threshold);
std::vector<float> biases = get_node_attr_af(node, "biases");
if (biases.size() > 0) {
fprintf(pp, " -23304=%zu", biases.size());
for (int i = 0; i < (int)biases.size(); i++) {
fprintf(pp, ",%e", biases[i]);
}
}
std::vector<float> mask = get_node_attr_af(node, "mask");
if (mask.size() > 0) {
fprintf(pp, " -23305=%zu", mask.size());
for (int i = 0; i < (int)mask.size(); i++) {
fprintf(pp, ",%e", mask[i]);
}
}
std::vector<float> anchors_scale = get_node_attr_af(node, "anchors_scale");
if (anchors_scale.size() > 0) {
fprintf(pp, " -23306=%zu", anchors_scale.size());
for (int i = 0; i < (int)anchors_scale.size(); i++) {
fprintf(pp, ",%e", anchors_scale[i]);
}
}
} else {
// TODO op specific param
}
fprintf(pp, "\n");
for (int j = 0; j < output_size; j++) {
const std::string& output_name = node.output(j);
if (node_reference.find(output_name) != node_reference.end()) {
int refcount = node_reference[output_name];
if (refcount > 1) {
char splitname[256];
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
fprintf(pp, " %s", output_name.c_str());
for (int k = 0; k < refcount; k++) {
fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
}
fprintf(pp, "\n");
internal_split++;
}
}
}
}
fclose(pp);
fclose(bp);
fprintf(stderr, "onnx2ncnn finish\n");
return 0;
}
// Copyright (c) OpenMMLab. All rights reserved.
#include "shape_inference.h"
#include <algorithm>
/**
* @brief query output shape of target node
*
* @param mutable_graph
* @param target
* @param weights
* @param context <tensor name, shape>
* @return std::tuple<bool, std::vector<int>>
*/
std::tuple<bool, std::vector<int>> query_shape(
onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
const std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, std::vector<int>>& context) {
// emplace all input nodes
const int input_count = mutable_graph->input_size();
for (int i = 0; i < input_count; i++) {
auto inp = mutable_graph->input(i);
onnx::TypeProto inp_type = inp.type();
onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();
auto dim_size = shape_proto.dim_size();
std::vector<int> shape(dim_size);
for (int index = 0; index < dim_size; ++index) {
shape[index] = shape_proto.dim(index).dim_value();
}
context.emplace(inp.name(), shape);
}
// BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes
std::vector<onnx::NodeProto*> serial = {target};
{
std::set<std::string> mark_as_appended = {};
while (true) {
int start = 0, end = serial.size();
for (int i = start; i < end; ++i) {
auto node_ptr = serial[i];
auto len = node_ptr->input_size();
for (int j = 0; j < len; ++j) {
std::string name = node_ptr->input(j);
if (context.find(name) != context.end()) {
// if input founded, skip
continue;
}
if (weights.find(name) != weights.end()) {
// if founded in weights, extract shape to context
auto weight = weights.at(name);
std::vector<int> shape;
for (auto index = 0; index < weight.dims_size(); ++index) {
shape.emplace_back(weight.dims(index));
}
context.emplace(name, shape);
continue;
}
if (mark_as_appended.find(name) != mark_as_appended.end()) {
// if mark as appended, skip
continue;
}
// else append it to serialization list
auto depend_ptr = find_node_by_output_name(mutable_graph, name);
if (depend_ptr == nullptr) {
fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
return std::make_tuple(false, std::vector<int>{});
}
mark_as_appended.insert(name);
serial.emplace_back(depend_ptr);
}
}
if (serial.size() <= end) {
// if not new node added, quit
break;
}
// update start and end position, continue BFS the tree
start = end;
end = serial.size();
}
}
// for each node in serialization list, calculate the output shape
{
std::reverse(serial.begin(), serial.end());
for (auto node : serial) {
if (node->op_type() == "Conv") {
auto inp = context[node->input(0)];
auto weight = context[node->input(1)];
assert(inp.size() == 4 and weight.size() == 4);
int group = get_node_attr_i(*node, "group", 1);
assert(group == 1);
// treat multiple spatial attr as single one
#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \
int ATTR = DEFAULT; \
{ \
std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
if (not _vec.empty()) { \
ATTR = _vec[0]; \
} \
}
EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
EXTRACT_REPEATED_PARAM("pads", pad, 0);
EXTRACT_REPEATED_PARAM("strides", stride, 1);
#undef EXTRACT_REPEATED_PARAM
int on = inp[0];
int oc = weight[0];
int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});
} else if (node->op_type() == "Shape") {
auto inp = context[node->input(0)];
context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});
} else if (node->op_type() == "Slice") {
assert(node->input_size() >= 4);
auto inp = context[node->input(0)];
int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));
if (axes != 0) {
fprintf(stderr, "Not support axes=%d !\n", axes);
return std::make_tuple(false, std::vector<int>{});
}
assert(inp.size() >= end - start);
context.emplace(node->output(0), std::vector<int>{inp.begin() + start, inp.begin() + end});
} else if (node->op_type() == "Concat") {
assert(node->input_size() >= 2);
auto axis = get_node_attr_i(*node, "axis", 0);
if (axis != 0) {
fprintf(stderr, "Not support axes=%d !\n", axis);
return std::make_tuple(false, std::vector<int>{});
}
std::vector<int> inp = context[node->input(0)];
std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));
// concat data on axis 0
inp.insert(inp.end(), w_data.begin(), w_data.end());
context.emplace(node->output(0), inp);
} else {
fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
return std::make_tuple(false, std::vector<int>{});
}
}
}
assert(context.find(target->output(0)) != context.end());
auto target_shape = context[target->output(0)];
return std::make_tuple(true, target_shape);
}
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "utils.h"
/**
* @brief query output shape of target node
*
* @param mutable_graph
* @param target
* @param weights
* @param context <tensor name, shape>
* @return std::tuple<bool, std::vector<int>>
*/
std::tuple<bool, std::vector<int>> query_shape(
onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
const std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, std::vector<int>>& context);
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <float.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <limits.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include "onnx.pb.h"
/**
* @brief find graph node by output name
*
* @param graph
* @param name
* @return onnx::NodeProto*
*/
static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
const std::string& name) {
const int input_count = mutable_graph->node_size();
for (int i = 0; i < input_count; ++i) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
for (int j = 0; j < node->output_size(); ++j) {
auto output = node->output(j);
if (output == name) {
return node;
}
}
}
return nullptr;
}
static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) {
std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) {
fprintf(stderr, "open failed %s\n", filepath);
return false;
}
google::protobuf::io::IstreamInputStream input(&fs);
google::protobuf::io::CodedInputStream codedstr(&input);
#if GOOGLE_PROTOBUF_VERSION >= 3011000
codedstr.SetTotalBytesLimit(INT_MAX);
#else
codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
#endif
bool success = message->ParseFromCodedStream(&codedstr);
fs.close();
return success;
}
static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key) {
std::vector<int> v;
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
v.resize(attr.ints_size());
for (int j = 0; j < attr.ints_size(); j++) {
v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
break;
}
}
return v;
}
static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
const std::vector<int>& value) {
onnx::AttributeProto* attr_group = node.add_attribute();
attr_group->set_name(key);
for (auto v : value) {
attr_group->add_ints(v);
}
return;
}
static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
std::vector<float> v;
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
v.resize(attr.floats_size());
for (int j = 0; j < attr.floats_size(); j++) {
v[j] = attr.floats(j);
}
break;
}
}
return v;
}
static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
}
return def;
}
static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.f();
}
}
return def;
}
static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key,
const std::string& def = std::string()) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.s();
}
}
return def;
}
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.t();
}
}
return onnx::TensorProto();
}
template <typename T>
static T get_node_attr_from_input(const onnx::TensorProto& tp) {
T v = 0.f;
// float
if (tp.data_type() == 1) {
const float* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const float*)tp.raw_data().data();
} else {
shape_data = tp.float_data().data();
}
v = shape_data[0];
}
// double
else if (tp.data_type() == 11) {
const double* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const double*)tp.raw_data().data();
} else {
shape_data = tp.double_data().data();
}
v = shape_data[0];
}
// int64
else if (tp.data_type() == 7) {
const int64_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int64_t*)tp.raw_data().data();
} else {
shape_data = tp.int64_data().data();
}
v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
// int32
else if (tp.data_type() == 6) {
const int32_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int32_t*)tp.raw_data().data();
} else {
shape_data = tp.int32_data().data();
}
v = shape_data[0];
} else {
// fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
fprintf(stderr, "get_node_attr_from_input\n");
abort();
}
return v;
}
static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp) {
int size = 0;
std::vector<int> v;
// int64
if (tp.data_type() == 7) {
const int64_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int64_t*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 8);
} else {
shape_data = tp.int64_data().data();
size = tp.int64_data_size();
}
for (int j = 0; j < size; j++) {
int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
v.push_back(vi);
}
}
// int32
else if (tp.data_type() == 6) {
const int32_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int32_t*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 4);
} else {
shape_data = tp.int32_data().data();
size = tp.int32_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back(shape_data[j]);
}
} else {
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
}
return v;
}
static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp) {
int size = 0;
std::vector<float> v;
// float
if (tp.data_type() == 1) {
const float* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const float*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 4);
} else {
shape_data = tp.float_data().data();
size = tp.float_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back(shape_data[j]);
}
}
// double
else if (tp.data_type() == 11) {
const double* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const double*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 8);
} else {
shape_data = tp.double_data().data();
size = tp.double_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back((float)shape_data[j]);
}
} else {
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
}
return v;
}
static int get_tensor_proto_data_size(const onnx::TensorProto& tp) {
if (tp.has_raw_data()) {
if (tp.data_type() == 1 || tp.data_type() == 6) {
const std::string& raw_data = tp.raw_data();
int size = (int)raw_data.size() / 4;
return size;
} else if (tp.data_type() == 7 || tp.data_type() == 11) {
const std::string& raw_data = tp.raw_data();
int size = (int)raw_data.size() / 8;
return size;
} else if (tp.data_type() == 9) {
const std::string& raw_data = tp.raw_data();
return 0;
}
} else if (tp.data_type() == 1) {
return tp.float_data_size();
} else if (tp.data_type() == 7) {
return tp.int64_data_size();
} else if (tp.data_type() == 6) {
return tp.int32_data_size();
} else if (tp.data_type() == 11) {
return tp.double_data_size();
}
return 0;
}
static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) {
int size = get_tensor_proto_data_size(tp);
if (tp.has_raw_data()) {
const std::string& raw_data = tp.raw_data();
fwrite(raw_data.data(), sizeof(float), size, bp);
} else if (tp.data_type() == 1) {
fwrite(tp.float_data().data(), sizeof(float), size, bp);
}
}
static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) {
int size = get_tensor_proto_data_size(tp);
size_t written_size;
if (tp.has_raw_data()) {
const std::string& raw_data = tp.raw_data();
if (tp.data_type() == 6) {
int* intdataptr = (int*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 7) {
int64_t* intdataptr = (int64_t*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 9) {
bool* intdataptr = (bool*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 11) {
double* doubledataptr = (double*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)doubledataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
}
} else if (tp.data_type() == 6) {
int* intdataptr = (int*)tp.int32_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 7) {
int64_t* intdataptr = (int64_t*)tp.int64_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 9) {
int* intdataptr = (int*)tp.int64_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 11) {
double* doubledataptr = (double*)tp.double_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)doubledataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
}
}
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_ncnn_ops)
# add plugin source
file(GLOB_RECURSE NCNN_OPS_SRCS *.cpp)
add_library(${PROJECT_NAME}_obj OBJECT "${NCNN_OPS_SRCS}")
target_compile_definitions(${PROJECT_NAME}_obj PRIVATE -DMMDEPLOY_API_EXPORTS=1)
set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1)
target_link_libraries(${PROJECT_NAME}_obj PRIVATE ncnn)
set(_COMMON_INCLUDE_DIRS
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/csrc>)
target_include_directories(${PROJECT_NAME}_obj
PUBLIC ${_COMMON_INCLUDE_DIRS})
mmdeploy_export(${PROJECT_NAME}_obj)
mmdeploy_add_library(${PROJECT_NAME} SHARED EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_NAME}_obj)
target_include_directories(${PROJECT_NAME}
PUBLIC ${_COMMON_INCLUDE_DIRS})
add_library(mmdeploy::ncnn_ops ALIAS ${PROJECT_NAME})
set(_NCNN_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_NCNN_OPS_DIR})
// Copyright (c) OpenMMLab. All rights reserved.
#include "constantofshape.h"
#include "../ncnn_ops_definer.h"
namespace mmdeploy {
using namespace ncnn;
DEFINE_LAYER_CREATOR(ConstantOfShape)
DEFINE_NCNN_OPS(ConstantOfShape, ConstantOfShape)
ConstantOfShape::ConstantOfShape() {
one_blob_only = true;
support_inplace = false;
}
int ConstantOfShape::load_param(const ParamDict& pd) {
val = pd.get(0, 0.f);
return 0;
}
int ConstantOfShape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const {
int dims = bottom_blob.w - 1;
const float* bottom_ptr = bottom_blob;
const float* shape_ptr = bottom_ptr + 1;
if (dims == 1) {
int w = (int)(shape_ptr[0] + 0.5);
size_t elemsize = sizeof(val);
top_blob.create(w, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
top_blob.fill(val);
return 0;
} else if (dims == 2) {
int h = (int)(shape_ptr[0] + 0.5);
int w = (int)(shape_ptr[1] + 0.5);
size_t elemsize = sizeof(val);
top_blob.create(w, h, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
top_blob.fill(val);
return 0;
} else if (dims == 3) {
int channels = (int)(shape_ptr[0] + 0.5);
int h = (int)(shape_ptr[1] + 0.5);
int w = (int)(shape_ptr[2] + 0.5);
size_t elemsize = sizeof(val);
top_blob.create(w, h, channels, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
top_blob.fill(val);
return 0;
}
return -1;
}
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef LAYER_CONSTANTOFSHAPE_H
#define LAYER_CONSTANTOFSHAPE_H
#include "layer.h"
namespace mmdeploy {
class ConstantOfShape : public ncnn::Layer {
public:
ConstantOfShape();
virtual int load_param(const ncnn::ParamDict& pd);
virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob,
const ncnn::Option& opt) const;
public:
float val;
};
} // namespace mmdeploy
#endif // LAYER_CONSTANTOFSHAPE_H
// Copyright (c) OpenMMLab. All rights reserved.
// right alignment broadcast (c, h, w).
// the same as onnx
#include "expand.h"
#include "../ncnn_ops_definer.h"
namespace mmdeploy {
using namespace ncnn;
DEFINE_LAYER_CREATOR(Expand)
DEFINE_NCNN_OPS(Expand, Expand)
Expand::Expand() {
one_blob_only = false;
support_inplace = false;
}
int Expand::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs,
const Option& opt) const {
const Mat& bottom_blob = bottom_blobs[0];
size_t elemsize = bottom_blob.elemsize;
const Mat& old_shape_blob = bottom_blobs[1];
const int shape_width = old_shape_blob.w - 1;
Mat shape_blob(shape_width, elemsize, opt.workspace_allocator);
memcpy(shape_blob.row(0), old_shape_blob.row(0) + 1, shape_width * elemsize);
Mat& top_blob = top_blobs[0];
if (bottom_blob.dims == 1 && shape_blob.w == 1) {
int shape_0 = (int)(shape_blob[0] + 0.5);
if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d) vs (%d)\n", bottom_blob.w, shape_0);
} else if (bottom_blob.w == shape_0 || shape_0 == 1) {
top_blob.create(bottom_blob.w, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int i = 0; i < bottom_blob.w; i++) {
top_blob[i] = bottom_blob[i];
}
} else if (bottom_blob.w == 1) {
top_blob.create(shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int i = 0; i < shape_0; i++) {
top_blob[i] = bottom_blob[0];
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
} else if (bottom_blob.dims == 1 && shape_blob.w == 2) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (1, %d) vs (%d, %d)\n", bottom_blob.w, shape_0,
shape_1);
} else if (bottom_blob.w == shape_1 || shape_1 == 1) {
top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.row(j)[i] = bottom_blob[i];
}
}
} else if (bottom_blob.w == 1) {
top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < shape_1; i++) {
top_blob.row(j)[i] = bottom_blob[0];
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
} else if (bottom_blob.dims == 1 && shape_blob.w == 3) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
int shape_2 = (int)(shape_blob[2] + 0.5);
if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (1, 1, %d) vs (%d, %d, %d)\n", bottom_blob.w,
shape_0, shape_1, shape_2);
} else if (bottom_blob.w == shape_2 || shape_2 == 1) {
top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob[i];
}
}
}
} else if (bottom_blob.w == 1) {
top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob[0];
}
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
} else if (bottom_blob.dims == 2 && shape_blob.w == 2) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h,
bottom_blob.w, shape_0, shape_1);
} else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h,
bottom_blob.w, shape_0, shape_1);
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
(bottom_blob.h == shape_0 || shape_0 == 1)) {
top_blob.create(bottom_blob.w, bottom_blob.h, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.row(j)[i] = bottom_blob.row(j)[i];
}
}
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) && (bottom_blob.h == 1)) {
top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.row(j)[i] = bottom_blob.row(0)[i];
}
}
} else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_0 || shape_0 == 1)) {
top_blob.create(shape_1, bottom_blob.h, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_1; i++) {
top_blob.row(j)[i] = bottom_blob.row(j)[0];
}
}
} else if (bottom_blob.h == 1 && bottom_blob.w == 1) {
top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < shape_1; i++) {
top_blob.row(j)[i] = bottom_blob.row(0)[0];
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
} else if (bottom_blob.dims == 2 && shape_blob.w == 3) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
int shape_2 = (int)(shape_blob[2] + 0.5);
if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h,
bottom_blob.w, shape_0, shape_1, shape_2);
} else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h,
bottom_blob.w, shape_0, shape_1, shape_2);
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) &&
(bottom_blob.h == shape_1 || shape_1 == 1)) {
top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[i];
}
}
}
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1)) {
top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[i];
}
}
}
} else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_1 || shape_1 == 1)) {
top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[0];
}
}
}
} else if (bottom_blob.h == 1 && bottom_blob.w == 1) {
top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[0];
}
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
} else if (bottom_blob.dims == 3 && shape_blob.w == 3) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
int shape_2 = (int)(shape_blob[2] + 0.5);
if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c,
bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2);
} else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c,
bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2);
} else if (bottom_blob.c != shape_0 && bottom_blob.c != 1 && shape_0 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c,
bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2);
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) &&
(bottom_blob.h == shape_1 || shape_1 == 1) &&
(bottom_blob.c == shape_0 || shape_0 == 1)) {
top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
}
}
}
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) &&
(bottom_blob.h == shape_1 || shape_1 == 1) && (bottom_blob.c == 1)) {
top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[i];
}
}
}
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) &&
(bottom_blob.c == shape_0 || shape_0 == 1)) {
top_blob.create(bottom_blob.w, shape_1, bottom_blob.c, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i];
}
}
}
} else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) &&
(bottom_blob.c == 1)) {
top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[i];
}
}
}
} else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) &&
(bottom_blob.c == shape_0 || shape_0 == 1)) {
top_blob.create(shape_2, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
}
}
}
} else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) &&
(bottom_blob.c == 1)) {
top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[0];
}
}
}
} else if (bottom_blob.w == 1 && bottom_blob.h == 1 &&
(bottom_blob.c == shape_0 || shape_0 == 1)) {
top_blob.create(shape_2, shape_1, bottom_blob.c, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0];
}
}
}
} else if (bottom_blob.w == 1 && bottom_blob.h == 1 && bottom_blob.c == 1) {
top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < shape_0; k++) {
for (int j = 0; j < shape_1; j++) {
for (int i = 0; i < shape_2; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[0];
}
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
}
fprintf(stderr, "Layer: Expand, bottom_blob.dims: %d, shape_blob.w: %d\n", bottom_blob.dims,
shape_blob.w);
return -1;
}
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef LAYER_EXPAND_H
#define LAYER_EXPAND_H
#include "layer.h"
namespace mmdeploy {
class Expand : public ncnn::Layer {
public:
Expand();
virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs,
const ncnn::Option& opt) const;
};
} // namespace mmdeploy
#endif // LAYER_EXPAND_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment