"docs/vscode:/vscode.git/clone" did not exist on "8ff039c2c709fe1fbeda1ba5bcc96f5287c3d2cd"
Commit 7bacd3ba authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into expand_dims_op

parents bbb5f61a 5af04e80
...@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z) ...@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max() / 64;
const auto half_max = max / 2; const auto half_max = max / 2;
return half_max - (z % max); return half_max - (z % max);
} }
...@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z) ...@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})> template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max() / 64;
return z % max; return z % max;
} }
......
...@@ -79,6 +79,7 @@ struct literal : raw_data<literal> ...@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template <class Iterator> template <class Iterator>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard()) if(m_shape.standard())
{ {
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); }); m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_sum"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
lens[axis] = 1;
return {s.type(), lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis : axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
......
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
......
...@@ -2,7 +2,17 @@ ...@@ -2,7 +2,17 @@
#include <migraphx/cpu/lowering.hpp> #include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
......
...@@ -35,6 +35,7 @@ add_library(migraphx_device ...@@ -35,6 +35,7 @@ add_library(migraphx_device
device/gather.cpp device/gather.cpp
device/sub.cpp device/sub.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
...@@ -70,6 +71,7 @@ add_library(migraphx_gpu ...@@ -70,6 +71,7 @@ add_library(migraphx_gpu
schedule_model.cpp schedule_model.cpp
adjust_allocation.cpp adjust_allocation.cpp
clip.cpp clip.cpp
reduce_sum.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
......
...@@ -50,6 +50,14 @@ struct hip_array ...@@ -50,6 +50,14 @@ struct hip_array
result[i] = x[i] * y[i]; result[i] = x[i] * y[i];
return result; return result;
} }
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y)
{
hip_array result{};
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i];
return result;
}
}; };
} // namespace device } // namespace device
......
...@@ -11,9 +11,33 @@ namespace device { ...@@ -11,9 +11,33 @@ namespace device {
struct index struct index
{ {
std::size_t global; std::size_t global = 0;
std::size_t local; std::size_t local = 0;
std::size_t group; std::size_t group = 0;
__device__ std::size_t nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ std::size_t nlocal() const { return blockDim.x; } // NOLINT
template <class F>
__device__ void global_stride(std::size_t n, F f) const
{
const auto stride = nglobal();
for(std::size_t i = global; i < n; i += stride)
{
f(i);
}
}
template <class F>
__device__ void local_stride(std::size_t n, F f) const
{
const auto stride = nlocal();
for(std::size_t i = local; i < n; i += stride)
{
f(i);
}
}
}; };
template <class F> template <class F>
...@@ -35,18 +59,26 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local) ...@@ -35,18 +59,26 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
}; };
} }
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index idx) -> decltype(f(i, idx))
{
return f(i, idx);
}
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))
{
return f(i);
}
inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024) inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
{ {
std::size_t groups = 1 + n / local; std::size_t groups = 1 + n / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local; std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) { launch(stream, nglobal, local)(
for(size_t i = idx.global; i < n; i += nglobal) [=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); });
{
f(i);
}
});
}; };
} }
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
struct sum
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
};
struct id
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
};
struct max
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x > y ? x : y;
}
};
struct min
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x < y ? x : y;
}
};
struct lowest
{
template <class T>
operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::lowest());
}
};
struct highest
{
template <class T>
operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::max());
}
};
#ifdef MIGRAPHX_NO_DPP
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
__syncthreads();
for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if(index + s < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
__syncthreads();
}
return buffer[0];
}
#else
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15: y = 0x142; break;
case 31: y = 0x143; break;
default: throw std::runtime_error("Unknown bcast");
}
return y;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
static const std::size_t n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output{};
type input{};
// cppcheck-suppress unreadVariable
input.data = x;
for(std::size_t i = 0; i < n; i++)
{
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
T out;
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
__device__ inline void dpp_reduce(float& x, sum)
{
#ifdef MIGRAPHX_USE_CLANG_TIDY
(void)x;
#else
__asm__ volatile("s_nop 4\n"
"v_add_f32 %0 %0 %0 row_shr:1\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:2\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:4 bank_mask:0xe\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc\n"
"s_nop 1\n"
: "=v"(x)
: "0"(x));
#endif
}
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / 64;
if((idx.local % 64) == 63)
{
buffer[ldsidx] = x;
}
__syncthreads();
type y = init;
for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
{
y = op(y, buffer[i]);
}
return y;
}
#endif
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 64;
while(block_size < max_block_size and block_size < n)
block_size *= 2;
return block_size;
}
template <class Op, class T, class Input, class Output>
void reduce(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output)
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens};
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
auto base_idx = output.get_shape().multi(out_idx);
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
auto reduce_idx = reduce_shape.multi(j);
return read_input(input[reduce_idx + base_idx]);
});
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -73,6 +73,22 @@ struct hip_shape ...@@ -73,6 +73,22 @@ struct hip_shape
} }
return result; return result;
} }
MIGRAPHX_DEVICE_CONSTEXPR hip_index carry(hip_index result) const
{
std::ptrdiff_t rem = 0;
for(std::ptrdiff_t i = result.size() - 1; i >= 0; i--)
{
auto z = result[i] + rem;
rem = z - std::ptrdiff_t(lens[i]) + 1;
if(rem > 0)
z -= rem;
else
rem = 0;
result[i] = z;
}
return result;
}
}; };
template <std::size_t N> template <std::size_t N>
......
...@@ -91,7 +91,7 @@ using device_type = typename detail::device_type<T>::type; ...@@ -91,7 +91,7 @@ using device_type = typename detail::device_type<T>::type;
template <class T> template <class T>
host_type<T> host_cast(T x) host_type<T> host_cast(T x)
{ {
return reinterpret_cast<host_type<T>>(x); return reinterpret_cast<const host_type<T>&>(x);
} }
template <class T> template <class T>
......
#include <migraphx/gpu/device/reduce_sum.hpp>
#include <migraphx/gpu/device/reduce.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
reduce(stream, result, arg, sum{}, 0, id{}, id{});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_SUM_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_SUM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_SUM_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_SUM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reduce_sum
{
op::reduce_sum op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::reduce_sum"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include <migraphx/gpu/lrn.hpp> #include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/convert.hpp> #include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -105,6 +106,7 @@ struct miopen_apply ...@@ -105,6 +106,7 @@ struct miopen_apply
add_extend_op<hip_pad, op::pad>("pad"); add_extend_op<hip_pad, op::pad>("pad");
add_extend_op<hip_convert, op::convert>("convert"); add_extend_op<hip_convert, op::convert>("convert");
add_extend_op<hip_clip, op::clip>("clip"); add_extend_op<hip_clip, op::clip>("clip");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_lrn_op(); add_lrn_op();
add_convolution_op(); add_convolution_op();
......
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reduce_sum.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_reduce_sum::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
argument
hip_reduce_sum::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::reduce_sum(ctx.get_stream().get(), args.back(), args.front());
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -112,6 +112,7 @@ struct tf_parser ...@@ -112,6 +112,7 @@ struct tf_parser
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Tanh", op::tanh{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
......
...@@ -1583,4 +1583,79 @@ TEST_CASE(clip_test) ...@@ -1583,4 +1583,79 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(reduce_sum_test0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{15, 18, 21, 24};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{4, 6, 12, 14, 20, 22};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3, 7, 11, 15, 19, 23};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{33, 45};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 26, 42};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment