Commit 920ed950 authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gpu_div

parents 81cc18c6 ab35b581
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>
#include <cstdint>
#include <numeric>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_mean
{
std::vector<std::int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_mean"; }
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
}
if(axis < 0)
{
axis += s_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_mean(input, batch_shape, tuned_axes, out_idx, output);
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
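A note on the new operator's axis handling: tune_axes is the only subtle part. An empty axes attribute means "reduce over every dimension", and negative axes count from the last dimension, ONNX-style. A standalone restatement of that normalization, with illustrative names that are not part of the MIGraphX API:

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

// Sketch of reduce_mean::tune_axes: empty list -> all dimensions;
// negative axes count from the end; out-of-range axes are rejected.
std::vector<std::int64_t> normalize_axes(std::vector<std::int64_t> axes, std::size_t n_dim)
{
    if(axes.empty())
    {
        axes.resize(n_dim);
        std::iota(axes.begin(), axes.end(), 0); // {0, 1, ..., n_dim - 1}
        return axes;
    }
    const auto s_dim = static_cast<std::int64_t>(n_dim);
    for(auto& axis : axes)
    {
        if(axis >= s_dim or axis < -s_dim)
            throw std::out_of_range("axis out of range");
        if(axis < 0)
            axis += s_dim; // e.g. -1 on a 4-D shape becomes 3
    }
    return axes;
}

For example, normalize_axes({-1}, 4) yields {3}, which is why the T{{-1}} case in the shape tests near the end of this diff reduces the last axis of a {2, 3, 4, 5} input to give {2, 3, 4, 1}.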
@@ -4,6 +4,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
@@ -13,7 +14,7 @@ namespace op {
struct reduce_sum
{
- std::vector<std::size_t> axes;
+ std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
@@ -23,25 +24,82 @@ struct reduce_sum
std::string name() const { return "reduce_sum"; }
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
}
if(axis < 0)
{
axis += s_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
- auto s = inputs.at(0);
- auto lens = s.lens();
- for(auto axis : axes)
+ auto s          = inputs.at(0);
+ auto lens       = s.lens();
+ auto tuned_axes = tune_axes(lens.size());
+ for(auto axis : tuned_axes)
{
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<int64_t> tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
- shape_for_each(input.get_shape(), [&](auto&& in_idx) {
- auto out_idx = in_idx;
- for(auto axis : axes)
- out_idx[axis] = 0;
- output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
+ par_for(output_shape.elements(), [&](auto i) {
+ auto out_idx = output_shape.multi(i);
+ this->calc_sum(input, batch_shape, tuned_axes, out_idx, output);
});
});
......
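How the new compute path works in both reduce_sum and reduce_mean above: each output element is computed independently (hence the par_for), the output coordinate fixes all non-reduced dimensions, and a "batch" shape, whose extents equal the input extents on the reduced axes and 1 elsewhere, is swept to visit every input element that folds into that output. A minimal host-side sketch of that walk over a dense row-major buffer, assuming already-normalized axes (illustrative only, not the MIGraphX API):

#include <cstddef>
#include <vector>

// Sum the inputs that fold into one output element of a reduction over
// `axes`, mirroring how calc_sum/calc_mean iterate batch_shape.
float reduce_one(const std::vector<float>& data,
                 const std::vector<std::size_t>& lens, // input extents, row-major
                 const std::vector<std::size_t>& axes, // reduced axes, normalized
                 std::vector<std::size_t> idx)         // output coordinate
{
    float val = 0.0f;
    std::vector<std::size_t> b(axes.size(), 0); // counters along the reduced axes
    for(;;)
    {
        for(std::size_t k = 0; k < axes.size(); ++k)
            idx[axes[k]] = b[k];
        std::size_t off = 0; // row-major linear offset of idx
        for(std::size_t d = 0; d < lens.size(); ++d)
            off = off * lens[d] + idx[d];
        val += data[off];
        // odometer-style increment over the reduced axes only
        std::size_t k = axes.size();
        while(k > 0)
        {
            if(++b[k - 1] < lens[axes[k - 1]])
                break;
            b[k - 1] = 0;
            --k;
        }
        if(k == 0)
            return val;
    }
}

For lens {3, 2, 2} and axes {1}, calling reduce_one with output coordinate {0, 0, 1} visits inputs (0,0,1) and (0,1,1): the same walk calc_sum performs for that element, and dividing the result by the batch element count is exactly calc_mean.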
@@ -46,6 +46,7 @@
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
......
@@ -96,7 +96,8 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
// init the activation function map
init_actv_func();
@@ -1288,20 +1289,21 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
- instruction_ref parse_reduce_sum(const std::string&,
- attribute_map attributes,
- std::vector<instruction_ref> args)
+ template <class T>
+ instruction_ref parse_reduce_oper(const std::string&,
+ attribute_map attributes,
+ std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
- std::vector<std::size_t> axes(n_dim);
+ std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
- axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
+ axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
@@ -1312,13 +1314,12 @@ struct onnx_parser
if(keep_dims == 1)
{
- return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
+ return prog.add_instruction(T{axes}, std::move(args));
}
else
{
- auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
- std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
- return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
+ auto ins = prog.add_instruction(T{axes}, std::move(args));
+ return prog.add_instruction(op::squeeze{axes}, ins);
}
}
......
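For the keepdims handling above: the reduce operator itself always preserves rank (reduced extents become 1), and when keepdims is 0 the parser appends a squeeze over the same axes. A shape-level sketch with a hypothetical helper name (the real logic lives in the instructions the parser emits):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Output extents of a reduction over normalized `axes`, with and without
// keepdims. E.g. lens {3, 4, 5, 6}, axes {2, 3}:
//   keepdims -> {3, 4, 1, 1};  no keepdims -> {3, 4}.
std::vector<std::size_t> reduce_output_lens(std::vector<std::size_t> lens,
                                            const std::vector<std::int64_t>& axes,
                                            bool keepdims)
{
    for(auto axis : axes)
        lens[axis] = 1; // the reduce op keeps rank
    if(keepdims)
        return lens;
    std::vector<std::size_t> squeezed; // mimic the appended squeeze
    for(std::size_t i = 0; i < lens.size(); ++i)
        if(std::find(axes.begin(), axes.end(), static_cast<std::int64_t>(i)) == axes.end())
            squeezed.push_back(lens[i]);
    return squeezed;
}

This is why reducemean_test1 later in this diff (axes {2, 3}, keepdims 0) parses to reduce_mean followed by squeeze, while reducemean_test2 (default keepdims) parses to reduce_mean alone.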
@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "@literal")
return true;
if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front());
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const
ins->outputs().end());
for(auto child : children)
{
- if(skip_propogate(child))
+ if(child->name() == "@literal" or skip_propogate(child))
{
self(child);
continue;
......
@@ -40,6 +40,7 @@ add_library(migraphx_device
device/div.cpp
device/clip.cpp
device/reduce_sum.cpp
device/reduce_mean.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
@@ -78,6 +79,7 @@ add_library(migraphx_gpu
adjust_allocation.cpp
clip.cpp
reduce_sum.cpp
reduce_mean.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
@@ -28,6 +28,16 @@ struct id
}
};
struct mean
{
size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x / item_num);
}
};
struct max
{
template <class T, class U>
......
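The new mean functor rescales an already-summed value by the reduced element count. For integral tensors the division truncates, which the reduce_mean_int test later in this diff relies on (gold {2, 6, 10} where the exact means are {2.5, 6.5, 10.5}). A host-side restatement, dropping MIGRAPHX_DEVICE_CONSTEXPR (illustrative only):

#include <cstddef>
#include <iostream>

// Host-side sketch of the device `mean` functor above.
struct mean_functor
{
    std::size_t item_num = 1;
    template <class T>
    T operator()(T x) const
    {
        return static_cast<T>(x / item_num);
    }
};

int main()
{
    std::cout << mean_functor{4}(10.0f) << '\n'; // 2.5
    std::cout << mean_functor{4}(10) << '\n';    // 2: integer division truncates
}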
#include <migraphx/gpu/device/reduce_mean.hpp>
#include <migraphx/gpu/device/reduce.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
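Here item_num is the number of input elements folded into each output element. Because the result keeps rank with extent-1 reduced axes, the ratio of element counts equals the product of the reduced extents: for example, reducing {3, 1024, 8, 8} over axis 1 (as test_reduce_mean_int does later in this diff) gives {3, 1, 8, 8}, so item_num = (3*1024*8*8) / (3*1*8*8) = 1024, the extent of the reduced axis.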
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reduce_mean
{
op::reduce_mean op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::reduce_mean"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -52,6 +52,7 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -115,6 +116,7 @@ struct miopen_apply
add_extend_op<hip_convert, op::convert>("convert");
add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
add_lrn_op();
add_convolution_op();
......
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reduce_mean.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_reduce_mean::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument
hip_reduce_mean::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::reduce_mean(ctx.get_stream().get(), args.back(), args.front());
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
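A note on the wrapper pattern above, as inferred from this code rather than a documented contract: GPU ops receive a preallocated output buffer as their final argument. compute_shape pops that trailing shape before delegating validation to the wrapped operator, compute writes through args.back() and returns it, and output_alias reports the trailing index so the framework knows the result aliases that buffer instead of freshly allocated memory.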
@@ -1703,7 +1703,7 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
- TEST_CASE(reduce_sum_test0)
+ TEST_CASE(reduce_sum_axis0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1718,7 +1718,7 @@ TEST_CASE(reduce_sum_test0)
EXPECT(results_vector == gold);
}
- TEST_CASE(reduce_sum_test1)
+ TEST_CASE(reduce_sum_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1733,7 +1733,7 @@ TEST_CASE(reduce_sum_test1)
EXPECT(results_vector == gold);
}
- TEST_CASE(reduce_sum_test2)
+ TEST_CASE(reduce_sum_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1748,7 +1748,7 @@ TEST_CASE(reduce_sum_test2)
EXPECT(results_vector == gold);
}
- TEST_CASE(reduce_sum_test02)
+ TEST_CASE(reduce_sum_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1763,7 +1763,7 @@ TEST_CASE(reduce_sum_test02)
EXPECT(results_vector == gold);
}
- TEST_CASE(reduce_sum_test12)
+ TEST_CASE(reduce_sum_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1778,4 +1778,79 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -3543,4 +3543,40 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -858,6 +858,27 @@ TEST_CASE(reducesum_test3)
EXPECT(p == prog);
}
TEST_CASE(reducemean_test1)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemean_test1.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducemean_test2)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemean_test2.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test)
{
migraphx::program p;
......
[binary file: ONNX protobuf for the reducemean parser tests (graph "test_reducemean": one ReduceMean node with "axes" and "keepdims" attributes, input x, output y); payload omitted]
@@ -390,37 +390,111 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
- template <class T>
- void test_argop_var()
+ TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
- expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
+ expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
+ migraphx::op::argmax{0},
+ input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::op::argmax{1},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::op::argmax{2},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::op::argmax{3},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmax{4}, input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
- expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
+ expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
+ migraphx::op::argmin{0},
+ input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
- expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
+ expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
+ migraphx::op::argmin{1},
+ input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
- expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
+ expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
+ migraphx::op::argmin{2},
+ input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::op::argmin{3},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmin{4}, input);
}
}
template <class T>
void test_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
- throws_shape(T{4}, input);
+ throws_shape(T{{4}}, input);
}
}
- TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
- TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
+ TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
+ TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
// 2 inputs arguments
TEST_CASE(matmul)
......