Commit 5d057776 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents d6b4ae77 9b19b73f
...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss) ...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
auto each = [&](auto x) { auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>; using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = s.element_space();
if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{}) not is_same<T, type>{})
return f(x, offset, false_type{}); return f(x, offset, false_type{});
......
...@@ -19,7 +19,7 @@ struct max_pool ...@@ -19,7 +19,7 @@ struct max_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int)
{ {
return (x); return (x);
} }
...@@ -36,21 +36,19 @@ struct avg_pool ...@@ -36,21 +36,19 @@ struct avg_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? 0.0 : (x / y);
} }
}; };
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const array<std::size_t, 2>& dims, const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
array<float, 2> xy,
Op pooling)
{ {
array<int, 2> low{}; array<int, 2> low{};
array<int, 2> high{}; array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii) for(index_int ii = 0; ii < xy.size(); ++ii)
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, ...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
xy[ii] = high[ii] = low[ii] = dims[ii] - 1; xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
} }
} }
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1], array<index_int, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1], low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23); return pooling(v01, v23);
} }
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data, MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
const array<float, 2>& roi_starts, const array<float, 2>& roi_starts,
const array<float, 2>& bin_size, const array<float, 2>& bin_size,
const array<int, 2>& idx, const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size, const array<index_int, 2>& bin_grid_size,
const array<std::size_t, 2>& dims, const array<index_int, 2>& dims,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
T output_val = op.init(); typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset; roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs) ...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
template <class T, class U, class V, class W, class Settings> template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s) __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{ {
auto index = make_index(); auto index = make_index();
const auto* x = x_t.data(); const auto x = x_t.begin();
const auto* rois = rois_t.data(); const auto rois = rois_t.begin();
const auto* ind = ind_t.data(); const auto ind = ind_t.begin();
auto* out_ptr = y_t.data(); auto out_ptr = y_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim // input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]}; array<index_int, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal(); const auto stride = index.nglobal();
auto out_s = y_t.get_shape(); auto out_s = y_t.get_shape();
...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
// output dims of height and width, in all 2-dim arrays, the first dim // output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
const auto& out_lens = out_s.lens; const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]}; array<index_int, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride) for(index_int i = index.global; i < out_s.elements(); i += stride)
{ {
...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
int ph = idx[2]; int ph = idx[2];
int pw = idx[3]; int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale}; offset_rois[0] * s.spatial_scale};
...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{}; array<index_int, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii) for(index_int ii = 0; ii < roi_size.size(); ++ii)
{ {
roi_size[ii] = roi_ends[ii] - roi_starts[ii]; roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f); roi_size[ii] = max(roi_size[ii], 1.0f);
...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]); (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
} }
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
out_ptr[i] = calc_pooling(offset_x, out_ptr[i] = calc_pooling(offset_x,
......
...@@ -17,35 +17,38 @@ struct shape ...@@ -17,35 +17,38 @@ struct shape
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {} constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr index_int elements() const { return lens.product(); } constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr bool packed() const { return elements() == element_space(); } constexpr auto packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr bool transposed() const constexpr auto transposed() const
{ {
if(broadcasted()) return return_c([] {
{ auto lstrides = Strides{};
index_array s; if(shape{}.broadcasted())
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
if(strides[i] != 0) index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
s[j] = strides[i]; if(lstrides[i] != 0)
j++; {
s[j] = lstrides[i];
j++;
}
} }
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
return not is_sorted(s.begin(), s.begin() + j, greater{}); else
} {
else return not is_sorted(lstrides.begin(), lstrides.end(), greater{});
{ }
return not is_sorted(strides.begin(), strides.end(), greater{}); });
}
} }
constexpr bool standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
constexpr index_int index(index_array x) const { return x.dot(strides); } constexpr index_int index(index_array x) const { return x.dot(strides); }
...@@ -63,10 +66,10 @@ struct shape ...@@ -63,10 +66,10 @@ struct shape
return i; return i;
else else
{ {
const index_int rank = this->lens.size(); const auto rank = this->lens.size();
index_int s = 1; index_int s = 1;
index_int result = 0; index_int result = 0;
for(index_int j = 0; j < this->lens.size(); j++) for(index_int j = 0; j < rank; j++)
{ {
const index_int k = rank - j - 1; const index_int k = rank - j - 1;
const index_int stride = this->strides[k]; const index_int stride = this->strides[k];
......
...@@ -3,17 +3,30 @@ ...@@ -3,17 +3,30 @@
#include <migraphx/kernels/shape.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx { namespace migraphx {
template <class T>
struct tensor_view_iterator_read
{
T* view;
constexpr auto& operator()(std::size_t n) const
{
MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n];
}
};
template <class T, class Shape> template <class T, class Shape>
struct tensor_view struct tensor_view
{ {
using type = T; using type = T;
using shape_type = Shape; using shape_type = Shape;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); } constexpr auto size() const { return get_shape().elements(); }
template <class U> template <class U>
constexpr T& operator[](U i) const constexpr T& operator[](U i) const
...@@ -24,8 +37,8 @@ struct tensor_view ...@@ -24,8 +37,8 @@ struct tensor_view
constexpr T* data() const { return x; } constexpr T* data() const { return x; }
constexpr T* begin() const { return data(); } constexpr auto begin() const { return iterator{0, {this}}; }
constexpr T* end() const { return data() + size(); } constexpr auto end() const { return iterator{this->size(), {this}}; }
template <class U> template <class U>
constexpr tensor_view<U, Shape> with(U* y) const constexpr tensor_view<U, Shape> with(U* y) const
......
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
namespace migraphx { namespace migraphx {
template <class T>
struct type_identity
{
using type = T;
};
template <bool B, class T = void> template <bool B, class T = void>
struct enable_if struct enable_if
{ {
...@@ -35,6 +41,33 @@ struct is_same<T, T> : true_type ...@@ -35,6 +41,33 @@ struct is_same<T, T> : true_type
{ {
}; };
template <class T>
struct remove_reference
{
using type = T;
};
template <class T>
struct remove_reference<T&>
{
using type = T;
};
template <class T>
struct remove_reference<T&&>
{
using type = T;
};
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
template <class T>
struct add_pointer : type_identity<typename remove_reference<T>::type*>
{
};
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
namespace migraphx { namespace migraphx {
using index_int = std::uint32_t; using index_int = std::uint32_t;
using diff_int = std::int32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
#include <migraphx/ref/gemm.hpp> #include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
namespace migraphx { namespace migraphx {
...@@ -74,8 +74,10 @@ void migemm_impl( ...@@ -74,8 +74,10 @@ void migemm_impl(
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0; double s = 0.0;
......
...@@ -819,9 +819,9 @@ struct ref_apply ...@@ -819,9 +819,9 @@ struct ref_apply
void apply_pooling(instruction_ref ins) const void apply_pooling(instruction_ref ins) const
{ {
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
} }
}; };
......
...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info, tf_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
op::pooling op{starts_with(opd.tf_name, "Max") ? "max" : "average"}; if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
{
MIGRAPHX_THROW("tf pooling mode must be Max or Average");
}
op::pooling op{starts_with(opd.tf_name, "Max") ? op::pooling_mode::max
: op::pooling_mode::average};
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
......
...@@ -13,6 +13,7 @@ endfunction() ...@@ -13,6 +13,7 @@ endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run) ...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
} }
TEST_CASE(load_and_run_ctx)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
for(auto&& name : param_shapes.names())
{
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto ctx = p.experimental_get_context();
p.eval(pp);
ctx.finish();
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
auto run_prog = [&](auto cond) { auto run_prog = [&](auto cond) {
......
#include <numeric>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(add_op)
{
migraphx::program p;
migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape);
auto y = m.add_parameter("y", param_shape);
auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y});
m.add_return({r});
// run on ref target
p.compile(migraphx::target("ref"));
migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<float> expected(9, 0);
CHECK(bool(output == migraphx::argument(param_shape, expected.data())));
}
TEST_CASE(if_then_else_op)
{
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
migraphx::shape cond_s{migraphx_shape_bool_type};
auto create_program = [&]() {
migraphx::program p;
auto mm = p.get_main_module();
auto cond = mm.add_parameter("cond", cond_s);
auto x = mm.add_parameter("x", param_shape);
auto y = mm.add_parameter("y", param_shape);
auto then_mod = p.create_module("If_0_if");
auto x_identity = then_mod.add_instruction(migraphx::operation("identity"), {x});
then_mod.add_return({x_identity});
auto else_mod = p.create_module("If_0_else");
auto y_identity = else_mod.add_instruction(migraphx::operation("identity"), {y});
else_mod.add_return({y_identity});
auto if_ins = mm.add_instruction(migraphx::operation("if"), {cond}, {then_mod, else_mod});
auto get_tuple_op = migraphx::operation("get_tuple_elem", "{index: 0}");
auto ret = mm.add_instruction(get_tuple_op, {if_ins});
mm.add_return({ret});
return p;
};
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
auto x_arg = migraphx::argument(param_shape, x_data.data());
auto y_arg = migraphx::argument(param_shape, y_data.data());
auto run_prog = [&](bool cond) {
auto p = create_program();
p.compile(migraphx::target("ref"));
auto outputs =
p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
return outputs;
};
// then branch
auto then_res = run_prog(true);
CHECK(bool{then_res[0] == x_arg});
// else branch
auto else_res = run_prog(false);
CHECK(bool{else_res[0] == y_arg});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad) ...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(padded_img, channels, m); auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, m); auto l1 = create_conv(padded_img, channels, m);
auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img); auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
......
...@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad) ...@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(l_img, channels, m); auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m); auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction( auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img); migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 1, 1}}}),
l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m); run_pass(m);
...@@ -76,8 +78,10 @@ TEST_CASE(rewrite_pad_symmetric) ...@@ -76,8 +78,10 @@ TEST_CASE(rewrite_pad_symmetric)
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input}); auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}), m.add_instruction(
l_img); migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m); run_pass(m);
EXPECT(std::none_of( EXPECT(std::none_of(
......
celu_alpha_test:R

xy"Celu*
alphaL?celu_alpha_testZ
x

b
y

B
\ No newline at end of file
celu_default_test:K
xy"Celucelu_default_testZ
x


b
y


B
\ No newline at end of file
celu_wrong_type_test:N
xy"Celucelu_wrong_type_testZ
x



b
y



B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment