Commit b7c6b8ca authored by Shucai Xiao

Merge reduce_mean support into the test_bert branch

parents 4d358059 a4055723
......@@ -12,7 +12,7 @@ namespace op {
struct argmax
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -25,8 +25,8 @@ struct argmax
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
//#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
//#include <cmath>
//#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,7 +12,7 @@ namespace op {
struct argmin
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -30,8 +25,8 @@ struct argmin
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_mean
{
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_mean"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
{
if(axis < 0 or axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_mean(input, batch_shape, out_idx, output);
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
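
Note on the calc_mean traversal above: the output index is held fixed while a "batch" shape, set to 1 on every non-reduced axis, sweeps the reduced axes; the accumulated sum is then divided by the batch element count. A minimal standalone sketch (plain C++, no MIGraphX types; the values match the reduce_mean_test1 case added below):

#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<float> input(12);
    std::iota(input.begin(), input.end(), 1.0f); // {3, 2, 2} row-major, values 1..12

    // reduce_mean over axis 1: output shape {3, 1, 2}, batch shape {1, 2, 1}
    for(int i = 0; i < 3; ++i)
        for(int k = 0; k < 2; ++k)
        {
            float val = 0;
            for(int j = 0; j < 2; ++j)             // sweep the reduced axis
                val += input[(i * 2 + j) * 2 + k];
            std::cout << val / 2 << ' ';           // divide by batch elements
        }
    std::cout << '\n';                             // prints: 2 3 6 7 10 11
}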
......@@ -4,6 +4,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <vector>
......@@ -13,7 +14,7 @@ namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -29,19 +30,47 @@ struct reduce_sum
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
{
if(axis < 0 or axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis : axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_sum(input, batch_shape, out_idx, output);
});
});
......
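Note on the reduce_sum rewrite above: switching from shape_for_each over the input to par_for over the output gives every output element a private accumulator, so the loop can run in parallel without atomics, whereas the old input-order loop scatter-added into shared output slots. A hedged sketch of the two traversal orders on a {2, 2} input reduced over axis 1:

#include <iostream>

int main()
{
    float in[4]  = {1, 2, 3, 4}; // viewed as {2, 2}, reduced over axis 1
    float out[2] = {0, 0};

    // Old order: walk the input, scatter-add into out. Two inputs share one
    // output slot, so a naive parallel version of this loop would race.
    for(int idx = 0; idx < 4; ++idx)
        out[idx / 2] += in[idx];

    // New order: walk the output; each element gathers its own sum. This
    // independence is what par_for over output_shape.elements() relies on.
    for(int o = 0; o < 2; ++o)
    {
        float val = 0;
        for(int j = 0; j < 2; ++j)
            val += in[o * 2 + j];
        out[o] = val;
    }
    std::cout << out[0] << ' ' << out[1] << '\n'; // prints: 3 7
}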
......@@ -46,6 +46,7 @@
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
......
......@@ -100,6 +100,7 @@ struct onnx_parser
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_mean);
// init the activation function map
init_actv_func();
......@@ -285,10 +286,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -300,7 +301,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......@@ -312,10 +313,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -327,7 +328,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......@@ -1378,13 +1379,13 @@ struct onnx_parser
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1400,8 +1401,40 @@ struct onnx_parser
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref parse_reduce_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_mean{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
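
Note on the keepdims == 0 path above: reduce_mean keeps each reduced dimension as 1 and the follow-up squeeze drops it. A shape-only sketch (standalone C++; the {3, 2, 2} input and axes {1, 2} are illustrative values, not taken from this diff):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{3, 2, 2};
    std::vector<int64_t> axes{1, 2};

    for(auto a : axes)
        lens[a] = 1;                   // after reduce_mean: {3, 1, 1}

    std::vector<std::size_t> squeezed; // after squeeze{axes}: {3}
    for(std::size_t d = 0; d < lens.size(); ++d)
        if(std::find(axes.begin(), axes.end(), static_cast<int64_t>(d)) == axes.end())
            squeezed.push_back(lens[d]);

    for(auto v : squeezed)
        std::cout << v << ' ';         // prints: 3
    std::cout << '\n';
}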
......
......@@ -650,44 +650,6 @@ struct cpu_logsoftmax
}
};
struct cpu_argmax
{
op::argmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct cpu_argmin
{
op::argmin op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct cpu_apply
{
program* prog;
......@@ -707,8 +669,6 @@ struct cpu_apply
void init()
{
apply_map["argmax"] = extend_op<cpu_argmax, op::argmax>();
apply_map["argmin"] = extend_op<cpu_argmin, op::argmin>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
......
......@@ -39,6 +39,7 @@ add_library(migraphx_device
device/clip.cpp
device/reduce_sum.cpp
device/pow.cpp
device/reduce_mean.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
......@@ -77,6 +78,7 @@ add_library(migraphx_gpu
adjust_allocation.cpp
clip.cpp
reduce_sum.cpp
reduce_mean.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
......@@ -12,12 +12,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg.visit([&](auto input) {
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
arg_op<type, argmax_op<type>>(argmax_op<type>{}, stream, result, arg, axis);
});
arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device
......
......@@ -12,12 +12,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg.visit([&](auto input) {
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
arg_op<type, argmin_op<type>>(argmin_op<type>{}, stream, result, arg, axis);
});
arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device
......
......@@ -28,6 +28,16 @@ struct id
}
};
struct scale
{
size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x / item_num);
}
};
struct max
{
template <class T, class U>
......
#include <migraphx/gpu/device/reduce_mean.hpp>
#include <migraphx/gpu/device/reduce.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
reduce(stream, result, arg, sum{}, 0, id{}, scale{item_num});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
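
Note on item_num above: because the output lens are already 1 along every reduced axis, the ratio of input to output element counts equals the number of values averaged per output, which scale then divides by. A quick standalone check, assuming an illustrative {3, 2, 2} input reduced over axes {1, 2}:

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t in_elements  = 3 * 2 * 2; // input shape {3, 2, 2}
    std::size_t out_elements = 3 * 1 * 1; // output shape after reducing axes {1, 2}
    std::cout << in_elements / out_elements << '\n'; // prints: 4 (values per mean)
}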
......@@ -22,8 +22,20 @@ struct val_index
};
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
return {v, -1};
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
return {v, i};
}
struct argmax_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
......@@ -36,12 +48,12 @@ struct argmax_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); }
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
template <class T>
struct argmin_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
......@@ -54,11 +66,11 @@ struct argmin_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); }
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class T, class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
......@@ -69,28 +81,28 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch.
const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
T init_val = op.init();
val_index<T> init = {init_val, -1};
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
auto init = make_val_index<type>(op.init());
auto op_output = block_reduce<max_block_size, Op, val_index<T>>(
idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
T val = input[arg_s.index(data_idx)];
return val_index<T>{val, static_cast<int64_t>(j)};
});
auto op_output =
block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return make_val_index(input[arg_s.index(data_idx)], j);
});
if(idx.local == 0)
{
output[batch_s.index(batch_idx)] = op_output.index;
}
});
if(idx.local == 0)
{
output[batch_s.index(batch_idx)] = op_output.index;
}
});
});
}
......
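Note on arg_op above: each lane pairs a value with its position along the reduced axis, and the op keeps the winning pair; block_reduce performs the same combination in parallel within a block. A sequential host-side sketch (standalone C++; val_index_f and the sample row are illustrative, and with a strict '>' the first occurrence of the maximum wins in this sequential form):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

struct val_index_f
{
    float val;
    int64_t index;
};

int main()
{
    std::vector<float> row{3.0f, 9.0f, 1.0f, 9.0f}; // one batch along the reduced axis
    val_index_f best{std::numeric_limits<float>::lowest(), -1}; // op.init() analogue

    for(std::size_t j = 0; j < row.size(); ++j)
    {
        val_index_f cur{row[j], static_cast<int64_t>(j)};
        if(cur.val > best.val) // argmax_op's comparison; '<' would give argmin
            best = cur;
    }
    std::cout << best.index << '\n'; // prints: 1
}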
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reduce_mean
{
op::reduce_mean op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::reduce_mean"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -51,6 +51,7 @@
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -113,6 +114,7 @@ struct miopen_apply
add_extend_op<hip_convert, op::convert>("convert");
add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
add_lrn_op();
add_convolution_op();
......
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reduce_mean.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_reduce_mean::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument
hip_reduce_mean::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::reduce_mean(ctx.get_stream().get(), args.back(), args.front());
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
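
Note on the calling convention assumed above: the GPU pipeline appends a preallocated output buffer as the last argument, so compute_shape pops it before validating, compute writes into args.back(), and output_alias reports shapes.size() - 1. A toy standalone analogue of the drop-the-last-argument check:

#include <cassert>
#include <vector>

// Toy analogue of hip_reduce_mean::compute_shape: validate only the data
// arguments; the trailing entry stands in for the preallocated output buffer.
int data_arg_count(std::vector<int> shapes)
{
    assert(!shapes.empty());
    shapes.pop_back(); // drop the output allocation, as compute_shape does
    return static_cast<int>(shapes.size());
}

int main() { return data_arg_count({/*input*/ 0, /*output*/ 1}) == 1 ? 0 : 1; }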
......@@ -1778,4 +1778,79 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
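
Note on the reduce_mean_int gold values above: the accumulator has the tensor's own type, so the mean truncates under integer division; the per-output sums 10, 26, and 42 divided by 4 give 2, 6, and 10. A one-liner check:

#include <iostream>

int main()
{
    // Per-output sums for axes {1, 2} of the int32 input above, followed by
    // the truncating integer division the mean performs.
    int sums[] = {1 + 2 + 3 + 4, 5 + 6 + 7 + 8, 9 + 10 + 11 + 12};
    for(int s : sums)
        std::cout << s / 4 << ' '; // prints: 2 6 10
    std::cout << '\n';
}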
int main(int argc, const char* argv[]) { test::run(argc, argv); }