"include/ck/utility/functional.hpp" did not exist on "569ad66e2a03789c4a1fa6659dc8296b4dfb868b"
Unverified commit 63563da2, authored by Paul Fultz II and committed by GitHub
Browse files

Add dimensionality reduction to device functions (#587)



* Add reduce dims

* Formatting

* Reduce dims on the gpu

* Formatting

* Fix tidy issues

* Convert to assert

* Reduce dims for contiguous

* Formatting

* Remove move

* Fix arguments used

* Formatting

* Fix warnings

* Formatting
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
parent a5648d9c
...@@ -21,6 +21,7 @@ add_library(migraphx ...@@ -21,6 +21,7 @@ add_library(migraphx
instruction.cpp instruction.cpp
program.cpp program.cpp
quantization.cpp quantization.cpp
reduce_dims.cpp
remap.cpp remap.cpp
shape.cpp shape.cpp
schedule.cpp schedule.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP

#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Jointly collapse compatible adjacent dimensions across a set of shapes.
// Returns one reduced shape per input shape (same order, same element types,
// lower or equal rank). When the shapes cannot be reduced together (e.g.
// mismatched ranks or incompatible, non-broadcastable lengths), the input
// shapes are returned unchanged. An empty input yields an empty result.
std::vector<shape> reduce_dims(const std::vector<shape>& shapes);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/reduce_dims.hpp>

#include <algorithm>
#include <cassert>
#include <iterator>
#include <numeric>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Try to merge dimension n with dimension n+1 simultaneously in every shape.
// The merge happens only if it is valid for ALL shapes; otherwise no shape is
// modified. Returns true when the merge was performed.
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
{
    std::vector<std::size_t> new_lens;
    for(const auto& s : shapes)
    {
        assert(n < s.lens().size());
        // Dimension n needs a successor to merge with.
        if((n + 1) >= s.lens().size())
            return false;
        auto astride = s.strides()[n];
        auto alen    = s.lens()[n];
        auto bstride = s.strides()[n + 1];
        auto blen    = s.lens()[n + 1];
        // Axes n and n+1 are fusable when one step along n spans exactly blen
        // steps along n+1 (i.e. the pair is contiguous in memory). Note this
        // also holds for broadcast pairs where both strides are 0.
        if(astride == bstride * blen)
        {
            new_lens.push_back(alen * blen);
        }
    }
    // Only merge when every shape contributed a merged length.
    if(new_lens.size() != shapes.size())
        return false;
    std::size_t i = 0;
    for(auto& s : shapes)
    {
        auto lens    = s.lens();
        auto strides = s.strides();
        // Drop axis n; the old axis n+1 shifts into slot n. Overwrite it with
        // the fused length while keeping the inner stride (old strides[n+1]).
        lens.erase(lens.begin() + n);
        strides.erase(strides.begin() + n);
        lens[n] = new_lens[i];
        s       = shape{s.type(), lens, strides};
        i++;
    }
    return true;
}
// Repeatedly fuse dimension n with its successor until no further merge is
// possible, then report the next axis to examine.
// NOTE(review): the `n < shapes.size()` guard compares n against the number
// of shapes rather than the number of dimensions -- reduce_dim's own bounds
// check still terminates the loop, but confirm this guard is intentional.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    bool keep_going = true;
    while(keep_going)
    {
        keep_going = reduce_dim(shapes, n) and n < shapes.size();
    }
    return n + 1;
}
// Sweep every axis of the (equal-rank) shapes from outermost to innermost,
// fusing adjacent dimensions wherever all shapes allow it.
void reduce_dim_all(std::vector<shape>& shapes)
{
    for(std::size_t axis = 0; axis < shapes.front().lens().size() - 1;)
        axis = reduce_dim_all(shapes, axis);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
{
return std::accumulate(
shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
std::vector<std::size_t> result;
const auto* x = &s.lens();
const auto* y = &lens;
if(x->size() > y->size())
std::swap(x, y);
std::transform(
x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
return std::max(a, b);
});
return result;
});
}
// Build a packed shape over `lens` that mimics s's broadcast pattern: axes
// where s already matches `lens` get packed strides, while broadcast axes
// (length 1 on either side) keep a stride of 0. Returns an empty shape when
// an axis is neither matching nor broadcastable.
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
    assert(s.lens().size() == lens.size());
    std::vector<std::size_t> rstrides(lens.size());
    std::size_t stride = 1;
    // Walk the axes from innermost to outermost, accumulating packed strides.
    std::size_t i = lens.size();
    while(i > 0)
    {
        --i;
        const auto target = lens[i];
        const auto actual = s.lens()[i];
        if(target == actual)
        {
            rstrides[i] = stride;
            stride *= target;
        }
        else if(target != 1 and actual != 1)
        {
            // Incompatible, non-broadcastable lengths: signal failure.
            return shape{};
        }
    }
    return shape{s.type(), lens, rstrides};
}
// Jointly collapse compatible adjacent dimensions across all shapes. Falls
// back to returning the inputs unchanged whenever a joint reduction is not
// possible (rank mismatch or incompatible lengths).
std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto result = shapes;
    auto base   = base_lens(shapes);
    for(auto&& s : shapes)
    {
        // base_lens truncates to the smallest rank, so a size mismatch here
        // means the ranks differ -- bail out unchanged.
        if(s.lens().size() != base.size())
            return shapes;
        if(s.lens() == base)
            continue;
        // Broadcast inputs get a packed "mask" shape appended so the joint
        // merge also respects their broadcast layout.
        auto mshape = mask_shape(s, base);
        // An empty mask shape signals incompatible lengths.
        if(mshape.lens().size() != base.size())
            return shapes;
        result.push_back(mshape);
    }
    reduce_dim_all(result);
    // Drop the helper mask shapes; return only the reduced originals.
    result.erase(result.begin() + shapes.size(), result.end());
    return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -105,7 +105,6 @@ add_library(migraphx_gpu ...@@ -105,7 +105,6 @@ add_library(migraphx_gpu
quant_convolution.cpp quant_convolution.cpp
softmax.cpp softmax.cpp
logsoftmax.cpp logsoftmax.cpp
contiguous.cpp
concat.cpp concat.cpp
leaky_relu.cpp leaky_relu.cpp
batchnorm.cpp batchnorm.cpp
......
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Validate the two inputs (source tensor + preallocated output buffer) and
// derive the result shape from the wrapped contiguous op using only the
// source shape.
shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2);
    return op.compute_shape({inputs.at(0)});
}
// Copy args[0] (possibly non-standard layout) into the standard-layout output
// buffer args[1] on the stream, returning the buffer (this op aliases its
// last input as the output).
argument miopen_contiguous::compute(context& ctx,
                                    shape output_shape,
                                    const std::vector<argument>& args) const
{
    assert(output_shape == args[1].get_shape());
    assert(output_shape.standard());
    (void)output_shape; // only used by the asserts above
    device::contiguous(ctx.get_stream().get(), args.at(1), args.at(0));
    return args.at(1);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -22,7 +22,7 @@ argument concat(hipStream_t stream, ...@@ -22,7 +22,7 @@ argument concat(hipStream_t stream,
auto output_shape = shape{ auto output_shape = shape{
arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()}; arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
auto output = argument{output_shape, args.back().data() + byte_offset}; auto output = argument{output_shape, args.back().data() + byte_offset};
contiguous(stream, std::move(output), arg); contiguous(stream, output, arg);
} }
return args.back(); return args.back();
} }
......
...@@ -7,9 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -7,9 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void contiguous(hipStream_t stream, argument result, argument arg) void contiguous(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, std::move(result), std::move(arg))([](auto x) __device__ { return x; }); nary(stream, result, arg)([](auto x) __device__ { return x; });
} }
} // namespace device } // namespace device
......
...@@ -34,17 +34,7 @@ constexpr void visit_tensor_size(index_int n, F f) ...@@ -34,17 +34,7 @@ constexpr void visit_tensor_size(index_int n, F f)
f(std::integral_constant<index_int, 4>{}); f(std::integral_constant<index_int, 4>{});
break; break;
} }
case 5: default: throw std::runtime_error("Tensor dims " + std::to_string(n) + " out of range");
{
f(std::integral_constant<index_int, 5>{});
break;
}
case 6:
{
f(std::integral_constant<index_int, 6>{});
break;
}
default: throw std::runtime_error("Tensor size dim out of range");
} }
} }
......
...@@ -169,68 +169,16 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -169,68 +169,16 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation); contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation);
} }
struct hip_triadd struct hip_triadd : ternary_device<hip_triadd, &device::add>
{ {
std::string name() const { return "hip::triadd"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
struct hip_triadd_clip struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
{ {
std::string name() const { return "hip::triadd_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(6);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(ctx.get_stream().get(),
args.at(5),
args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
return args.at(5);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
struct hip_add_clip struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
{ {
std::string name() const { return "hip::add_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(
ctx.get_stream().get(), args.at(4), args.at(0), args.at(1), args.at(2), args.at(3));
return args.at(4);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu> struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
...@@ -273,43 +221,12 @@ struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new> ...@@ -273,43 +221,12 @@ struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{ {
}; };
struct hip_mul_add struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
{ {
std::string name() const { return "hip::mul_add"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
struct hip_mul_add_relu struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
{ {
std::string name() const { return "hip::mul_add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
void move_broadcasted_back(std::vector<instruction_ref>& args) void move_broadcasted_back(std::vector<instruction_ref>& args)
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,22 +12,17 @@ namespace gpu { ...@@ -10,22 +12,17 @@ namespace gpu {
struct context; struct context;
struct miopen_contiguous struct miopen_contiguous : unary_device<miopen_contiguous, &device::contiguous>
{ {
op::contiguous op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::contiguous"; } std::string name() const { return "gpu::contiguous"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const
argument compute(context&, shape output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; check_shapes{inputs, *this}.has(2);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
} }
}; };
......
...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void contiguous(hipStream_t stream, argument result, argument arg); void contiguous(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <utility> #include <utility>
#include <iostream> #include <iostream>
...@@ -15,95 +16,126 @@ namespace migraphx { ...@@ -15,95 +16,126 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&)> template <class Derived, std::size_t N>
struct unary_device : oper<Derived> struct device_base : oper<Derived>
{ {
shape compute_shape(const std::vector<shape>& inputs) const template <class Self, class F>
static auto reflect(Self&, F)
{ {
check_shapes{inputs, *this}.has(2); return pack();
auto s = inputs.at(0);
if(s.packed())
{
return s;
}
else
{
return {s.type(), s.lens()};
}
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const std::vector<shape> reduce_shapes;
void finalize(context&, const shape&, const std::vector<shape>& inputs)
{ {
F(ctx.get_stream().get(), args[1], args[0]); reduce_shapes = reduce_dims(inputs);
return args[1];
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const argument get_arg(const std::vector<argument>& args, std::size_t i) const
{ {
return shapes.size() - 1; if(reduce_shapes.empty())
return args[i];
return args.at(i).reshape(reduce_shapes.at(i));
} }
};
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&, const argument&)>
struct binary_device : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(N + 1);
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto s1 = inputs.at(1); if(std::all_of(inputs.begin(), inputs.end() - 1, [&](auto s) { return s == s0; }) and
if(s0 == s1 and s0.packed()) s0.packed())
{
return s0; return s0;
}
else else
{
return {s0.type(), s0.lens()}; return {s0.type(), s0.lens()};
}
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
template <class Derived, void (*F)(hipStream_t, const argument&, const argument&)>
struct unary_device : device_base<Derived, 1>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
F(ctx.get_stream().get(), args[2], args[0], args[1]); F(ctx.get_stream().get(), this->get_arg(args, 1), this->get_arg(args, 0));
return args[2]; return args[1];
} }
};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const template <class Derived, void (*F)(hipStream_t, const argument&, const argument&, const argument&)>
struct binary_device : device_base<Derived, 2>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
return shapes.size() - 1; F(ctx.get_stream().get(),
this->get_arg(args, 2),
this->get_arg(args, 0),
this->get_arg(args, 1));
return args[2];
} }
}; };
template <class Derived, template <class Derived,
void (*F)( void (*F)(
hipStream_t, const argument&, const argument&, const argument&, const argument&)> hipStream_t, const argument&, const argument&, const argument&, const argument&)>
struct ternary_device : oper<Derived> struct ternary_device : device_base<Derived, 3>
{ {
shape compute_shape(const std::vector<shape>& inputs) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
check_shapes{inputs, *this}.has(4); F(ctx.get_stream().get(),
auto s0 = inputs.at(0); this->get_arg(args, 3),
auto s1 = inputs.at(1); this->get_arg(args, 0),
auto s2 = inputs.at(2); this->get_arg(args, 1),
if(s0 == s1 and s1 == s2 and s0.packed()) this->get_arg(args, 2));
{ return args[3];
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
} }
};
template <class Derived,
void (*F)(hipStream_t,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&)>
struct quaternary_device : device_base<Derived, 4>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
F(ctx.get_stream().get(), args[3], args[0], args[1], args[2]); F(ctx.get_stream().get(),
return args[3]; this->get_arg(args, 4),
this->get_arg(args, 0),
this->get_arg(args, 1),
this->get_arg(args, 2),
this->get_arg(args, 3));
return args[4];
} }
};
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const template <class Derived,
void (*F)(hipStream_t,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&,
const argument&)>
struct quinary_device : device_base<Derived, 5>
{
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
return shapes.size() - 1; F(ctx.get_stream().get(),
this->get_arg(args, 5),
this->get_arg(args, 0),
this->get_arg(args, 1),
this->get_arg(args, 2),
this->get_arg(args, 3),
this->get_arg(args, 4));
return args[5];
} }
}; };
......
...@@ -169,8 +169,8 @@ struct miopen_apply ...@@ -169,8 +169,8 @@ struct miopen_apply
add_generic_op<hip_ceil>("ceil"); add_generic_op<hip_ceil>("ceil");
add_generic_op<hip_floor>("floor"); add_generic_op<hip_floor>("floor");
add_generic_op<hip_recip>("recip"); add_generic_op<hip_recip>("recip");
add_generic_op<miopen_contiguous>("contiguous");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax"); add_extend_op<hip_softmax, op::softmax>("softmax");
add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax"); add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
......
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include "test.hpp"
// Convenience: build a standard (packed) float shape from dimension lengths.
migraphx::shape make_shape(std::vector<std::size_t> lens)
{
    return migraphx::shape{migraphx::shape::float_type, std::move(lens)};
}
// Convenience: build a float shape with explicit (possibly broadcast/0)
// strides.
migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
{
    return migraphx::shape{
        migraphx::shape::float_type, std::move(lens), std::move(strides)};
}
// Identical standard shapes collapse all the way down to a single 1-D shape.
TEST_CASE(same_standard)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64 * 3 * 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, is, is};
    std::vector<migraphx::shape> eshapes = {os, os, os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// A shape broadcast along axes 0, 2, 3 limits merging: only the two trailing
// 7x7 axes fuse, since the broadcast pattern must be preserved.
TEST_CASE(same_broadcast1)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// Broadcast input with non-trivial interior strides: axes 1-2 and 3-4 fuse,
// producing rank-3 results for all inputs.
TEST_CASE(same_broadcast2)
{
    auto is = make_shape({64, 3, 8, 7, 7});
    auto os = make_shape({64, 8 * 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// One input has its last two axes transposed: the leading 64x3 axes still
// fuse, but the transposed pair cannot.
TEST_CASE(same_transposed)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64 * 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is};
    std::vector<migraphx::shape> eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// A {1,3,1,1} bias-like shape is broadcast-compatible with the base lens, so
// the trailing axes of the full shapes still fuse.
TEST_CASE(different_masked1)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 1, 1}), is};
    std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), os};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// Two differently-broadcast inputs: each keeps its own broadcast pattern
// while the shared trailing 7x7 axes fuse.
TEST_CASE(different_masked2)
{
    auto is = make_shape({64, 3, 7, 7});
    auto os = make_shape({64, 3, 7 * 7});
    std::vector<migraphx::shape> ishapes = {
        is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})};
    std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// A length-2 axis that matches neither the base length nor 1 is not
// broadcastable, so reduce_dims returns the inputs unchanged.
TEST_CASE(different_incompatible)
{
    auto is = make_shape({64, 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 2, 1}), is};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(ishapes == rshapes);
}
// Mismatched ranks cannot be reduced jointly; inputs come back unchanged.
TEST_CASE(different_ranks)
{
    auto is = make_shape({64, 3, 7, 7});
    std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3}), is};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(ishapes == rshapes);
}
// Permuted outer strides block fusing the leading axes, but the contiguous
// 56x56 tail fuses in both shapes.
TEST_CASE(transposed1)
{
    std::vector<migraphx::shape> ishapes = {
        make_shape({8, 28, 4, 56, 56}),
        make_shape({8, 28, 4, 56, 56}, {351232, 3136, 87808, 56, 1})};
    std::vector<migraphx::shape> eshapes = {
        make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})};
    auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// reduce_dims of an empty list is an empty list (no crash on the empty case).
TEST_CASE(empty)
{
    auto rshapes = migraphx::reduce_dims({});
    EXPECT(rshapes.empty());
}
// Run all registered TEST_CASEs via the project's test driver.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment