Unverified commit 0859fe90 authored by kahmed10, committed by GitHub

Changes to support both OneDNN and ZenDNN builds (#929)
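
This change lets the MIGraphX CPU target build against either oneDNN or AMD's ZenDNN. The diff below adds a MIGRAPHX_ENABLE_ZENDNN CMake option (Off by default) that swaps find_package(dnnl) for find_library lookups of amdZenDNN and blis, and the dnnl wrapper header then selects the matching include, namespace alias, and argument-constant prefix. A minimal sketch of that compile-time switch, condensed from the header changes in this diff (the surrounding migraphx::cpu namespaces and file layout are omitted, and it assumes the headers of whichever library was chosen are installed):

#ifdef MIGRAPHX_ENABLE_ZENDNN
#include <zendnn.hpp>
namespace dnnl = zendnn; // alias so the dnnl:: names below resolve to ZenDNN's mirror of the oneDNN C++ API
#define MIGRAPHX_CONCAT_PREFIX(b) ZENDNN_##b // NOLINT
#else
#include <dnnl.hpp>
#define MIGRAPHX_CONCAT_PREFIX(b) DNNL_##b // NOLINT
#endif
// Expands argument constants for whichever backend is active,
// e.g. MIGRAPHX_DNNL_PREFIX(ARG_SRC) -> DNNL_ARG_SRC or ZENDNN_ARG_SRC.
#define MIGRAPHX_DNNL_PREFIX(b) MIGRAPHX_CONCAT_PREFIX(b) // NOLINT

Each dnnl_* operator then looks up its memory descriptors through MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_DST), and so on, so the same operator code compiles against both libraries.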



* Add preallocate method

* Add preallocate_param pass

* Preallocate buffers on the cpu

* Formatting

* Preallocate on the gpu

* Add missing cpp file

* Formatting

* Add lifetime function

* Formatting

* Improve handling of exceptions in test driver

* Formatting

* Auto print exception

* Formatting

* Fork each test case

* Formatting

* Exclude gcc 5 debug build

* Fix tidy issues

* Add color

* Formatting

* Create driver class

* Formatting

* Customize test_case names

* Formatting

* Report status from forked processes

* Formatting

* Update the verify driver

* Formatting

* Print out failed tests

* Formatting

* Fix tidy issues

* Formatting

* Expect passing

* Improve failure reporting on non-linux systems

* Fix ifdef

* Always allocate

* Fix tidy warning

* Flush code cov

* Formatting

* Fix tidy

* Add const

* Check if weak symbols are linked

* Formatting

* Initial progress

* Formatting

* Add continue flag

* Formatting

* Set exe name

* Use stringstream and use quotes

* Rename vars

* Formatting

* More testing

* Formatting

* Fix bug when using --continue in the tests

* Formatting

* Revert gemm

* Revert dot file

* Rename var

* Update CMakeLists and deconv compute
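
The "deconv compute" item above refers to the reference deconvolution being folded into the operator itself: the diff adds a compute method to op::deconvolution and drops the separate ref_deconvolution, so deconvolution can still be evaluated when a ZenDNN build skips the dnnl::deconvolution lowering. A rough, self-contained sketch of that scatter-style algorithm for a 1-D, single-channel case (the function name, the output-length formula, and the single-channel simplification are illustrative only, not the MIGraphX API):

#include <cstddef>
#include <iostream>
#include <vector>

// Scatter-style transposed convolution: each input element adds
// input[i] * weight[k] into output[i*stride + k*dilation - padding],
// mirroring the window walk in op::deconvolution::compute below.
std::vector<float> deconv1d(const std::vector<float>& input,
                            const std::vector<float>& weight,
                            std::size_t stride   = 1,
                            std::size_t dilation = 1,
                            std::size_t padding  = 0)
{
    std::size_t out_len =
        stride * (input.size() - 1) + dilation * (weight.size() - 1) + 1 - 2 * padding;
    std::vector<float> output(out_len, 0.0f);
    for(std::size_t i = 0; i < input.size(); i++)
    {
        for(std::size_t k = 0; k < weight.size(); k++)
        {
            auto o = static_cast<std::ptrdiff_t>(i * stride + k * dilation) -
                     static_cast<std::ptrdiff_t>(padding);
            // Skip positions that padding pushes outside the output,
            // as the bounds checks in the real compute do.
            if(o >= 0 and o < static_cast<std::ptrdiff_t>(out_len))
                output[o] += input[i] * weight[k];
        }
    }
    return output;
}

int main()
{
    // stride 2 upsamples {1, 2, 3} with weight {1, 1} to {1, 1, 2, 2, 3, 3}
    for(auto v : deconv1d({1.0f, 2.0f, 3.0f}, {1.0f, 1.0f}, 2))
        std::cout << v << " ";
    std::cout << "\n";
}
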
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent bd85a76c
......@@ -9,6 +9,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath>
#include <utility>
......@@ -70,6 +72,81 @@ struct deconvolution
return inputs[0].with_lens(output_lens);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto kdims = this->kdims();
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n]));
}
const int group_id = w / (wei_n / group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
size_t kdims() const
{
check_attribute_size();
......
......@@ -31,12 +31,29 @@ add_library(migraphx_cpu
set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION})
set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "")
find_package(Threads)
find_package(dnnl REQUIRED)
if(MIGRAPHX_ENABLE_ZENDNN)
find_path(ZENDNN_INC_PATH zendnn.hpp)
find_library(ZENDNN_LIB amdZenDNN)
find_library(BLIS_LIB blis)
else()
find_package(dnnl REQUIRED)
endif()
rocm_clang_tidy_check(migraphx_cpu)
if(MIGRAPHX_ENABLE_ZENDNN)
target_compile_definitions(migraphx_cpu PRIVATE -DMIGRAPHX_ENABLE_ZENDNN)
target_include_directories(migraphx_cpu PRIVATE ${ZENDNN_INC_PATH})
message(STATUS "ZENDNN_LIB: ${ZENDNN_LIB}")
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx Threads::Threads)
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
......
......@@ -37,7 +37,10 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC_0), m.at(DNNL_ARG_SRC_1), m.at(DNNL_ARG_DST)};
return {to_dnnl_algo(algo),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_1)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -11,7 +11,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_MULTIPLE_SRC);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC));
return result;
}
// Custom desc class since it's missing in dnnl
......@@ -28,9 +28,9 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
for(auto i = 0; i < m.size() - 1; i++)
{
srcs.push_back(m.at(DNNL_ARG_MULTIPLE_SRC + i));
srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i));
}
return {m.at(DNNL_ARG_DST), std::size_t(op.axis), srcs};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), std::size_t(op.axis), srcs};
}
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
......@@ -15,7 +15,10 @@ namespace cpu {
struct dnnl_convolution
: dnnl_extend_op<dnnl_convolution, dnnl::convolution_forward, op::convolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -45,9 +48,9 @@ struct dnnl_convolution
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_auto,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(padding_l),
......
......@@ -9,7 +9,10 @@ namespace cpu {
struct dnnl_deconvolution
: dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::deconvolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -35,9 +38,9 @@ struct dnnl_deconvolution
dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(op.padding),
......
......@@ -2,6 +2,9 @@
#if defined(__GNUC__) && __GNUC__ <= 5
namespace std {
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#endif
template <>
struct hash<dnnl::algorithm>
{
......
......@@ -39,7 +39,7 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
{
return {dnnl::prop_kind::forward_inference,
to_dnnl_algo(algo),
m.at(DNNL_ARG_SRC_0),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
alpha,
beta};
}
......
......@@ -13,13 +13,20 @@ namespace cpu {
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC),
MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS),
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -7,14 +7,26 @@
#include <migraphx/register_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <unordered_map>
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/assert.hpp>
#ifdef MIGRAPHX_ENABLE_ZENDNN
#include <zendnn.hpp>
#else
#include <dnnl.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#define MIGRAPHX_CONCAT_PREFIX(b) ZENDNN_##b // NOLINT
#else
#define MIGRAPHX_CONCAT_PREFIX(b) DNNL_##b // NOLINT
#endif
#define MIGRAPHX_DNNL_PREFIX(b) MIGRAPHX_CONCAT_PREFIX(b) // NOLINT
struct dnnl_context
{
dnnl::engine engine;
......@@ -102,7 +114,8 @@ struct dnnl_op : auto_register_op<Derived>
static std::size_t get_binary_post_op_arg(std::size_t pos)
{
return DNNL_ARG_ATTR_MULTIPLE_POST_OP(pos) | DNNL_ARG_SRC_1; // NOLINT
return MIGRAPHX_DNNL_PREFIX(ARG_ATTR_MULTIPLE_POST_OP)(pos) | // NOLINT
MIGRAPHX_DNNL_PREFIX(ARG_SRC_1); // NOLINT
}
static std::vector<shape> to_shapes(const std::vector<argument>& args)
......@@ -117,14 +130,18 @@ struct dnnl_op : auto_register_op<Derived>
{
auto desc = prim.get_primitive_desc();
const char* str = nullptr;
#ifdef MIGRAPHX_ENABLE_ZENDNN
zendnn_primitive_desc_query(desc, zendnn_query_impl_info_str, 0, &str);
#else
dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str);
#endif
return str == nullptr ? "" : str;
}
// Map arg index to arg in dnnl
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_SRC_0);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result;
}
shape base_adjust_shape(const shape& s) const
......@@ -183,8 +200,9 @@ struct dnnl_op : auto_register_op<Derived>
{
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
......@@ -201,7 +219,7 @@ struct dnnl_op : auto_register_op<Derived>
if(contains(op.algo, "binary_add"))
{
auto desc = m.at(arg);
if(desc == m.at(DNNL_ARG_DST))
if(desc == m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)))
po.append_sum(1.0f);
else
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
......@@ -328,7 +346,8 @@ struct dnnl_op : auto_register_op<Derived>
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[DNNL_ARG_DST] = to_dnnl_memory(md.at(DNNL_ARG_DST), args.back());
m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
for(int i = 0; i < args.size() - 1; i++)
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
......
......@@ -31,7 +31,7 @@ struct dnnl_layernorm : dnnl_op<dnnl_layernorm, dnnl::layer_normalization_forwar
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {dnnl::prop_kind::forward_inference,
m.at(DNNL_ARG_SRC),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
1e-12f,
dnnl::normalization_flags::none};
}
......
......@@ -12,7 +12,7 @@ struct dnnl_logsoftmax : dnnl_extend_op<dnnl_logsoftmax, dnnl::logsoftmax_forwar
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
int axis = this->op.axis;
return {dnnl::prop_kind::forward_inference, m.at(DNNL_ARG_SRC_0), axis};
return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis};
}
};
......
......@@ -392,8 +392,10 @@ struct cpu_apply
extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
#ifndef MIGRAPHX_ENABLE_ZENDNN
extend_op("deconvolution", "dnnl::deconvolution");
extend_op("dot", "dnnl::dot");
#endif
extend_op("erf", "cpu::erf");
extend_op("gather", "cpu::gather");
extend_op("logsoftmax", "dnnl::logsoftmax");
......
......@@ -12,7 +12,7 @@ struct dnnl_lrn : dnnl_extend_op<dnnl_lrn, dnnl::lrn_forward, op::lrn>
{
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::lrn_across_channels,
m.at(DNNL_ARG_SRC_0),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
this->op.size,
this->op.alpha,
this->op.beta,
......
......@@ -125,7 +125,7 @@ template struct cpu_pooling<max_pool>;
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC}; }
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
......@@ -135,8 +135,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference,
algo,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths),
to_dnnl_dims(padding_l),
......
File mode changed from 100755 to 100644
......@@ -37,7 +37,11 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
dnnl::reduction::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST), 0, 0};
return {to_dnnl_algo(algo),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
0,
0};
}
};
......
......@@ -27,7 +27,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
};
desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST)};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
......@@ -11,7 +11,7 @@ struct dnnl_softmax : dnnl_extend_op<dnnl_softmax, dnnl::softmax_forward, op::so
dnnl::softmax_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
int axis = this->op.axis;
return {dnnl::prop_kind::forward_inference, m.at(DNNL_ARG_SRC_0), axis};
return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis};
}
};
......
......@@ -269,99 +269,6 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
}
};
template <class Op>
struct ref_deconvolution : auto_register_op<ref_deconvolution<Op>>
{
ref_deconvolution() = default;
ref_deconvolution(Op pop) : op(std::move(pop)) {}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
auto kdims = op.kdims();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * op.stride[n]) -
std::ptrdiff_t(op.padding[n]));
}
const int group_id = w / (wei_n / op.group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * op.dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
};
struct ref_im2col
{
op::im2col op;
......@@ -917,10 +824,8 @@ struct ref_apply
apply_map["batch_norm_inference"] =
extend_op<ref_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
apply_map["deconvolution"] =
extend_op<ref_deconvolution<op::deconvolution>, op::deconvolution>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
......