Unverified Commit d73c6d7c authored by Charlie Lin, committed by GitHub

Dyn ref multibroadcast; dyn binary (#1423)

Updated the multibroadcast op with a two-input version for dynamic shapes.
Current dynamic shape broadcasting logic:
the dynamic_dimensions must be equal, or one of them must be {1, 1, 0} or {1, 1, 1}
Works for dyn-dyn, dyn-static, and static-static shape combinations
Changed common.cpp to multibroadcast binary ops with dynamic shapes
Extended binary.hpp to handle dynamic shapes and exercise the new common.cpp logic
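As a rough illustration of that rule (hypothetical values, not code from this commit), writing a dynamic_dimension as {min, max, opt}:

// Two dynamic_dimensions broadcast together if they are equal, or if one of
// them is a fixed length-1 dimension ({1, 1, 0} or {1, 1, 1}).
migraphx::shape::dynamic_dimension a{2, 4, 0};
migraphx::shape::dynamic_dimension b{2, 4, 0}; // equal to a -> result {2, 4, 0}
migraphx::shape::dynamic_dimension c{1, 1, 0}; // fixed 1 -> broadcasts with a to {2, 4, 0}
migraphx::shape::dynamic_dimension d{3, 5, 0}; // neither case -> mismatch, throws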
parent df2e7635
...@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
...@@ -50,25 +52,63 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
        return s0;
    if(s0.size() > s1.size())
        s0.swap(s1);
    std::vector<std::size_t> out_lens(s1);
    auto offset = s1.size() - s0.size();
    std::transform(
        s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
            if(a != b and a != 1 and b != 1)
            {
                MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
                               "} and {" + migraphx::to_string_range(s1) + "} mismatch!");
            }
            return std::max(a, b);
        });
    return out_lens;
}
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
    // change both shapes to dynamic_dimension representation
    s0 = s0.to_dynamic();
    s1 = s1.to_dynamic();
    if(s0.ndim() > s1.ndim())
    {
        std::swap(s0, s1);
    }
    auto offset = s1.ndim() - s0.ndim();
    std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
    shape::dynamic_dimension one_dyn_dim{1, 1, 0};
    std::transform(
        s0.dyn_dims().cbegin(),
        s0.dyn_dims().cend(),
        s1.dyn_dims().cbegin() + offset,
        out_dims.begin() + offset,
        [&](auto a, auto b) {
            if(a == b)
            {
                return a;
            }
            else if(a == one_dyn_dim or b == one_dyn_dim)
            {
                // setting opt to 0, may need to be changed
                return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
            }
            else
            {
                MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
                               migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
                               migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
            }
        });
    return out_dims;
}
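A usage sketch with assumed shapes (mirroring the shape tests later in this diff):

// Assumed example, not part of the diff:
migraphx::shape s0{migraphx::shape::float_type, {4, 4}}; // static, rank 2
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}}; // dynamic
// s0 becomes {{4, 4, 0}, {4, 4, 0}} via to_dynamic(), is right-aligned against s1,
// and the result is {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}.
auto out = migraphx::compute_broadcasted_dyn_dims(s0, s1);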
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
    assert(not shapes.empty());
    assert(
        std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
    return transform_accumulate(shapes.begin() + 1,
                                shapes.end(),
                                shapes.front().lens(),
...@@ -114,20 +154,63 @@ instruction_ref insert_common_op(module& m,
                                 const operation& op,
                                 std::vector<instruction_ref> inputs)
{
    if(std::any_of(
           inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
    {
        // currently only handles the binary case
        if(inputs.size() != 2)
        {
            MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
                           "inputs, only handle two inputs if any are dynamic shape");
        }
        auto c_type = compute_common_types(to_shapes(inputs));
        auto c_dyn_dims =
            compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
        // following should work for a static or dynamic shape
        if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
        {
            inputs[0] = m.insert_instruction(
                ins,
                make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
                inputs[0],
                inputs[1]);
        }
        if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
        {
            inputs[1] = m.insert_instruction(
                ins,
                make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
                inputs[1],
                inputs[0]);
        }
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
            if(input->get_shape().type() != c_type)
            {
                input =
                    m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
            }
            return input;
        });
    }
    else
    {
        auto common = common_shape(to_shapes(inputs));
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
            if(input->get_shape().lens() != common.lens())
            {
                input = m.insert_instruction(
                    ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
            }
            if(input->get_shape().type() != common.type())
            {
                input = m.insert_instruction(
                    ins, make_op("convert", {{"target_type", common.type()}}), input);
            }
            return input;
        });
    }
    return m.insert_instruction(ins, op, inputs);
}
...
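For orientation, a sketch of how this dynamic path gets exercised (illustrative shapes, mirroring the tests later in this diff; not code from the diff itself):

migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
    "x", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}}}); // dynamic
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}}); // static
// inserts the two-input multibroadcast (and a convert, if types differ) before "add"
auto sum = add_common_op(*mm, migraphx::make_op("add"), {x, y});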
...@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1);

std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);

shape common_shape(const std::vector<shape>& shapes);

instruction_ref insert_common_op(module& m,
...
...@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
    value attributes() const { return base_attributes(); }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, static_cast<const Derived&>(*this), true}
            .has(2)
            .same_type()
            .same_dims();
        auto s0 = inputs.at(0);
        auto s1 = inputs.at(1);
        if(s0.dynamic() or s1.dynamic())
        {
            if(s0 == s1)
                return s0;
            MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
        }
        else if(s0 == s1 and s0.packed())
        {
            return s0;
        }
...@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
        }
    }
    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        argument result{dyn_out.computed_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
            std::transform(input1.begin(),
                           input1.end(),
...
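A note on the dynamic branch of compute_shape above: both inputs must already carry identical dynamic shapes (insert_common_op is expected to have placed the broadcasts beforehand). Assumed illustrations:

// {{2, 6, 0}} vs {{2, 6, 0}} -> returns the common dynamic shape
// {{2, 6, 0}} vs static {3}  -> throws "BINARY: ...: fixed-dyn shape for inputs"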
...@@ -27,23 +27,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * 1 input version:
 * Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
 * broadcasted dimensions to zero. The `axis` attribute for a 1D input shape is the output
 * dimension that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
 * For higher rank input shapes, axis is an offset parameter for the broadcasting, such that
 * this operator works in the opposite direction of NumPy broadcasting.
 * ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
 *
 * 2 input version:
 * Broadcasts the first input (a 1D shape) into the second input's shape based on the axis
 * attribute. Handles broadcasting a 1D static shape into a higher rank dynamic shape.
 * broadcast_lens is not used.
 */
struct broadcast
{
    uint64_t axis = 0;
    std::vector<std::size_t> broadcast_lens = {};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
...@@ -54,36 +61,86 @@ struct broadcast
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this, true}.has(1, 2);
        auto s0 = inputs.at(0);
        auto t  = s0.type();

        if(inputs.size() == 1)
        {
            // the ONNX broadcast op is deprecated now, so not handling the negative
            // value of axis anymore
            if(axis >= broadcast_lens.size())
            {
                MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
                               " is out of range");
            }
            if(broadcast_lens.size() - axis < s0.lens().size())
            {
                MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
            }
            if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
            {
                MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
            }

            std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
            std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
            shape output{t, broadcast_lens, std::move(bcast_strides)};
            if(output.elements() < s0.elements())
            {
                // don't think this can occur?
                MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
            }
            return output;
        }
        else
        {
            // two inputs
            auto s1 = inputs.at(1);
            if(s0.dynamic())
            {
                MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
                               "a dynamic shape");
            }
            if(s0.ndim() != 1)
            {
                MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
                               ", only handle ndim = 1");
            }
            if(axis >= s1.ndim())
            {
                MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
                               " is out of range");
            }
            if(s1.dynamic())
            {
                s0 = s0.to_dynamic();
                if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
                {
                    MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
                                   "dimension length (" +
                                   migraphx::to_string(s0.dyn_dims()[0]) +
                                   " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
                }
                return s1;
            }
            if(s0.lens()[0] != s1.lens()[axis])
            {
                MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
                               "dimension length (" +
                               migraphx::to_string(s0.lens()[0]) +
                               " != " + migraphx::to_string(s1.lens()[axis]) + ")");
            }
            std::vector<size_t> bcast_strides(s1.ndim(), 0);
            std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
            shape output{t, s1.lens(), std::move(bcast_strides)};
            return output;
        }
    }

    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        return args[0].reshape(dyn_out.computed_shape);
    }
    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
...
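An assumed illustration of the two-input form (mirroring the shape tests later in this diff):

migraphx::shape s0{migraphx::shape::float_type, {4}};
migraphx::shape s1{migraphx::shape::float_type, {4, 4}};
// axis = 0: s0 is laid along the first output dimension
auto out = migraphx::op::broadcast{0}.compute_shape({s0, s1});
// out == {float_type, lens {4, 4}, strides {1, 0}}; with a dynamic s1,
// compute_shape simply returns s1's dynamic shape.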
...@@ -26,64 +26,105 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs.
* One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time.
*/
struct multibroadcast
{
    std::vector<std::size_t> output_lens = {};
    // optional attribute
    std::vector<shape::dynamic_dimension> output_dyn_dims = {};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "out_lens"), f(self.output_dyn_dims, "out_dyn_dims"));
    }

    std::string name() const { return "multibroadcast"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this, true}.has(1, 2);

        auto t  = inputs.at(0).type();
        auto s0 = inputs.at(0);

        if(s0.max_lens().empty())
        {
            MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
        }

        auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
            std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
            for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
            {
                if(bcast_lens[i + offset] == s0.lens()[i])
                {
                    bcast_strides[i + offset] = s0.strides()[i];
                }
            }
            return bcast_strides;
        };

        if(inputs.size() == 1)
        {
            if(s0.lens().size() > output_lens.size())
            {
                MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
            }

            auto offset = output_lens.size() - s0.lens().size();
            for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
            {
                if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
                {
                    MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
                                   "} cannot be broadcasted to {" + to_string_range(output_lens) +
                                   "}!");
                }
            }
            auto bcast_strides = make_bcast_strides(output_lens, offset);
            return {t, output_lens, std::move(bcast_strides)};
        }
        else
        {
            // two inputs
            auto s1 = inputs.at(1);
            if(s0.dynamic() or s1.dynamic())
            {
                if(not output_dyn_dims.empty())
                {
                    return {t, output_dyn_dims};
                }
                return {t, compute_broadcasted_dyn_dims(s0, s1)};
            }
            else
            {
                auto bcast_lens    = compute_broadcasted_lens(s0.lens(), s1.lens());
                auto offset        = bcast_lens.size() - s0.lens().size();
                auto bcast_strides = make_bcast_strides(bcast_lens, offset);
                return {t, std::move(bcast_lens), std::move(bcast_strides)};
            }
        }
    }

    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        return args[0].reshape(dyn_out.computed_shape);
    }

    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
...
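An assumed illustration of the two-input multibroadcast (matching the shape tests later in this diff):

migraphx::shape a{migraphx::shape::float_type, {1, 6}};                    // static
migraphx::shape b{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}};   // dynamic
auto out = migraphx::op::multibroadcast{}.compute_shape({a, b});
// {1, 6} converts to {{1, 1, 0}, {6, 6, 0}}; the fixed 1 broadcasts against
// {8, 8, 0}, so out is the dynamic shape {{8, 8, 0}, {6, 6, 0}}.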
...@@ -30,6 +30,7 @@
#include <numeric>
#include <memory>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
...@@ -89,7 +90,10 @@ struct shape
        std::size_t opt = 0;

        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
        }

        bool is_fixed() const;
        bool has_optimal() const;
...@@ -115,6 +119,12 @@ struct shape
    shape(type_t t, std::vector<dynamic_dimension> dims);

    // Construct a dynamic shape from three sets of lengths (of the same rank)
    shape(type_t t,
          std::vector<std::size_t> mins,
          std::vector<std::size_t> maxes,
          std::vector<std::size_t> opts);

    template <class Range>
    shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
    {
...@@ -136,6 +146,12 @@ struct shape
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;

    /*!
     * The number of dimensions in the shape.
     * Same as the number of indices required to get a data value.
     */
    std::size_t ndim() const;

    /*!
     * Return the number of elements in the tensor.
     */
...@@ -221,6 +237,9 @@ struct shape
    shape with_type(type_t t) const;

    // convert the shape to an equivalent dynamic shape
    shape to_dynamic() const;

    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);
    friend std::ostream& operator<<(std::ostream& os, const shape& x);
...
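A few assumed examples of the new shape API (values are hypothetical; behavior mirrors the implementations later in this diff):

migraphx::shape st{migraphx::shape::float_type, {2, 3}};
st.ndim();                 // 2, same as st.lens().size()
auto dy = st.to_dynamic(); // dynamic shape {{2, 2, 0}, {3, 3, 0}}
dy.ndim();                 // 2, same as dy.dyn_dims().size()
// new constructor from mins, maxes, opts:
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {4, 6}, {2, 4}};
// s.dyn_dims() == {{1, 4, 2}, {2, 6, 4}}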
...@@ -44,7 +44,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
    {
        epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
    }
    auto x_lens = args[0]->get_shape().max_lens();
    auto x_type = args[0]->get_shape().type();
    if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
...
...@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
        parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
    if(broadcasted != 0)
    {
        if(std::any_of(
               args.cbegin(), args.cend(), [](auto a) { return a->get_shape().dynamic(); }))
        {
            MIGRAPHX_THROW(
                "Binary op broadcast attribute not supported for dynamic input shapes");
        }
        uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
        auto l = info.add_instruction(
            make_op("broadcast",
...
...@@ -71,6 +71,19 @@ struct shape_impl ...@@ -71,6 +71,19 @@ struct shape_impl
{ {
} }
shape_impl(shape::type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: m_type(t)
{
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
}
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
...@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims) ...@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
{ {
} }
shape::shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
...@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } ...@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::ndim() const
{
if(this->dynamic())
{
return dyn_dims().size();
}
return lens().size();
}
std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
...@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const ...@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const
return {c}; return {c};
} }
shape shape::to_dynamic() const
{
if(this->dynamic())
{
return *this;
}
std::vector<std::size_t> zeroes(this->ndim(), 0);
return {type(), lens(), lens(), zeroes};
}
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
...@@ -464,15 +504,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; ...@@ -464,15 +504,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; } bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{ {
return (x.min == y.min and x.max == y.max and x.opt == y.opt); // don't check opt if both are fixed
return (x.min == y.min and x.max == y.max and
((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
} }
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......
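One behavioral consequence of the operator== change above, with assumed values: opt is ignored when both dimensions are fixed, so a static shape converted via to_dynamic() (opt 0) still compares equal to a fixed dimension that carries an optimal value.

migraphx::shape::dynamic_dimension p{4, 4, 0};
migraphx::shape::dynamic_dimension q{4, 4, 4};
// p == q is now true: both are fixed at 4, so opt is not compared.
// {2, 4, 0} == {2, 4, 2} remains false, since neither is fixed.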
...@@ -420,6 +420,74 @@ def batch_norm_invalid_bias_rank_test():
    return ([node], [x, scale, bias, mean, var], [out])
@onnx_test
def binary_dyn_brcst_prelu_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                         [None, 3, 4, 5])
    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5])
    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
                                            [None, 3, 4, 5])

    node = onnx.helper.make_node(
        'PRelu',
        inputs=['0', '1'],
        outputs=['out'],
    )

    return ([node], [arg0, arg1], [arg_out])


@onnx_test
def binary_dyn_brcst_add_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [4, 5])
    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                         [None, 3, 4, 5])
    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
                                            [None, 3, 4, 5])

    node = onnx.helper.make_node(
        'Add',
        inputs=['0', '1'],
        outputs=['out'],
    )

    return ([node], [arg0, arg1], [arg_out])


@onnx_test
def binary_dyn_brcst_attr_error_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [4, 5])
    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                         [None, 3, 4, 5])
    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
                                            [None, 3, 4, 5])

    node = onnx.helper.make_node('Add',
                                 inputs=['0', '1'],
                                 outputs=['out'],
                                 broadcast=1,
                                 axis=1)

    return ([node], [arg0, arg1], [arg_out])


@onnx_test
def binary_dyn_brcst_mul_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                         [None, 3, 4, 5])
    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 1])
    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
                                            [None, 3, 4, 5])

    node = onnx.helper.make_node(
        'Mul',
        inputs=['0', '1'],
        outputs=['out'],
    )

    return ([node], [arg0, arg1], [arg_out])
@onnx_test
def cast_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [10])
...
...@@ -521,6 +521,76 @@ TEST_CASE(batch_norm_invalid_bias_rank)
EXPECT(test::throws([&] { migraphx::parse_onnx("batch_norm_invalid_bias_rank.onnx"); }));
}
TEST_CASE(binary_dyn_brcst_prelu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto ret = add_common_op(*mm, migraphx::make_op("prelu"), {l0, l1});
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_prelu_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(binary_dyn_brcst_add_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, {4, 5}});
auto l1 = mm->add_parameter(
"1",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto ret = add_common_op(*mm, migraphx::make_op("add"), {l0, l1});
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_add_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(binary_dyn_brcst_attr_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("binary_dyn_brcst_attr_error_test.onnx", options); }));
}
TEST_CASE(binary_dyn_brcst_mul_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1}});
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast",
{{"out_dyn_dims", to_value(l0->get_shape().dyn_dims())}}),
l1,
l0);
auto ret = mm->add_instruction(migraphx::make_op("mul"), l0, bl1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_mul_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
...
...@@ -81,6 +81,14 @@ void throws_shape(const migraphx::shape&, Ts...)
"An expected shape should not be passed to throws_shape function");
}
TEST_CASE(binary_dyn_static_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("add"), a_shape, b_shape);
}
TEST_CASE(broadcast)
{
{
...@@ -118,6 +126,69 @@ TEST_CASE(broadcast)
}
}
TEST_CASE(broadcast_axis_out_of_range_error)
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 4}, {"out_lens", lens}}), input);
}
TEST_CASE(broadcast_2in_static_static)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
migraphx::make_op("broadcast", {{"axis", 0}}),
a_input,
b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_not_matching_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error1)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error2)
{
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
migraphx::shape a_input{migraphx::shape::float_type, dd};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_static_dyn)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(convolution_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
...@@ -1114,6 +1185,213 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE(multibroadcast_2in_static_dyn0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn_error0)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn_error1)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn_error2)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multibroadcast_2in_dyn_dyn0)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_dyn_dyn1)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
{
// max doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multibroadcast_2in_dyn_dyn_error1)
{
// opt doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multibroadcast_2in_static_static0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {8, 1}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {0, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {4, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static3)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {16, 4, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static4)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {4, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static_error0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
TEST_CASE(multinomial)
{
migraphx::shape s{migraphx::shape::float_type, {2, 5}};
...
...@@ -225,6 +225,30 @@ TEST_CASE(add_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(add_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("add"), x, y);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(argmax_test_0)
{
migraphx::program p;
...@@ -670,6 +694,52 @@ TEST_CASE(broadcast_test)
EXPECT(output(1, 1) == -3);
}
TEST_CASE(broadcast_2in_static_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}};
std::vector<int32_t> a_data{0, 0, 0, 0};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
uint64_t axis = 0;
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", axis}}), l2, l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2);
EXPECT(output(0, 1) == -2);
EXPECT(output(1, 0) == -3);
EXPECT(output(1, 1) == -3);
}
TEST_CASE(broadcast_2in_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 2, 0}, {2, 4, 0}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
uint64_t axis = 0;
auto pa = mm->add_parameter("a", a_shape);
auto lb = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", axis}}), lb, pa);
p.compile(migraphx::ref::target{});
std::vector<int32_t> a_data{0, 0, 0, 0};
migraphx::shape input_fixed_shape0{migraphx::shape::int32_type, {2, 2}};
migraphx::parameter_map params0;
params0["a"] = migraphx::argument(input_fixed_shape0, a_data.data());
auto result = p.eval(params0).back();
auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2);
EXPECT(output(0, 1) == -2);
EXPECT(output(1, 0) == -3);
EXPECT(output(1, 1) == -3);
}
TEST_CASE(ceil_test)
{
migraphx::program p;
...@@ -1860,6 +1930,32 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(div_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 3}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("div"), x, y);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1.0f, 0.5f, 1.0f};
std::vector<float> y_data{1.0f, 2.0f, 4.0f};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(x_data.size());
std::transform(
x_data.begin(), x_data.end(), y_data.begin(), gold.begin(), std::divides<float>());
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(elu_test)
{
migraphx::program p;
...@@ -1947,6 +2043,35 @@ TEST_CASE(equal_test)
EXPECT(results_vector == gold);
}
TEST_CASE(equal_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, 9}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto p0 = mm->add_parameter("l", s);
auto p1 = mm->add_parameter("r", s);
auto eq = mm->add_instruction(migraphx::make_op("equal"), p0, p1);
auto r = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> l_data{1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
std::vector<float> r_data{1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["l"] = migraphx::argument(input_fixed_shape0, l_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, r_data.data());
auto result = p.eval(params0).back();
std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {true, false, false, false, true, false, true, false, false};
EXPECT(results_vector == gold);
}
TEST_CASE(erf_test)
{
migraphx::program p;
...@@ -2607,6 +2732,35 @@ TEST_CASE(greater_test)
EXPECT(results_vector == gold);
}
TEST_CASE(greater_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
auto gr = mm->add_instruction(migraphx::make_op("greater"), left, right);
auto r = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
gr);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> left_data{1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
std::vector<float> right_data{1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["l"] = migraphx::argument(input_fixed_shape0, left_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, right_data.data());
auto result = p.eval(params0).back();
std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {false, false, true, true, false, true, false, false, true};
EXPECT(results_vector == gold);
}
TEST_CASE(identity_test)
{
migraphx::program p;
...@@ -3187,6 +3341,40 @@ TEST_CASE(less_test)
EXPECT(results_vector == gold);
}
TEST_CASE(less_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
auto le = mm->add_instruction(migraphx::make_op("less"), left, right);
auto r = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> left_data = {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
std::vector<float> right_data = {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["l"] = migraphx::argument(input_fixed_shape0, left_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, right_data.data());
auto result = p.eval(params0).back();
std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(left_data.size());
std::transform(left_data.begin(),
left_data.end(),
right_data.begin(),
gold.begin(),
[](float n1, float n2) -> bool { return n1 < n2; });
EXPECT(results_vector == gold);
}
TEST_CASE(log_test)
{
migraphx::program p;
...@@ -3250,6 +3438,35 @@ TEST_CASE(logical_and_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_and_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
mm->add_instruction(migraphx::make_op("logical_and"), left, right);
p.compile(migraphx::ref::target{});
std::vector<char> left_data{1, 0, 1, 0};
std::vector<char> right_data{1, 1, 0, 0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::bool_type, {4}};
params0["l"] = migraphx::argument(input_fixed_shape0, left_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, right_data.data());
auto result = p.eval(params0).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(left_data.size());
std::transform(left_data.begin(),
left_data.end(),
right_data.begin(),
gold.begin(),
[](bool n1, bool n2) -> bool { return n1 and n2; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_or_test)
{
migraphx::program p;
...@@ -3272,6 +3489,35 @@ TEST_CASE(logical_or_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_or_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
mm->add_instruction(migraphx::make_op("logical_or"), left, right);
p.compile(migraphx::ref::target{});
std::vector<char> left_data{1, 0, 1, 0};
std::vector<char> right_data{1, 1, 0, 0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::bool_type, {4}};
params0["l"] = migraphx::argument(input_fixed_shape0, left_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, right_data.data());
auto result = p.eval(params0).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(left_data.size());
std::transform(left_data.begin(),
left_data.end(),
right_data.begin(),
gold.begin(),
[](bool n1, bool n2) -> bool { return n1 or n2; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_xor_test)
{
migraphx::program p;
...@@ -3294,6 +3540,35 @@ TEST_CASE(logical_xor_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_xor_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
mm->add_instruction(migraphx::make_op("logical_xor"), left, right);
p.compile(migraphx::ref::target{});
std::vector<char> left_data{1, 0, 1, 0};
std::vector<char> right_data{1, 1, 0, 0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::bool_type, {4}};
params0["l"] = migraphx::argument(input_fixed_shape0, left_data.data());
params0["r"] = migraphx::argument(input_fixed_shape0, right_data.data());
auto result = p.eval(params0).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {false, true, true, false};
std::transform(left_data.begin(),
left_data.end(),
right_data.begin(),
gold.begin(),
[](bool n1, bool n2) -> bool { return n1 ^ n2; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
...@@ -3521,6 +3796,34 @@ TEST_CASE(max_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(max_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto curr_max = mm->add_instruction(migraphx::make_op("max"), x, y);
mm->add_instruction(migraphx::make_op("max"), curr_max, z);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{1, 4, 3};
std::vector<float> y_data{2, 8, 6};
std::vector<float> z_data{7, 5, 9};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
params0["z"] = migraphx::argument(input_fixed_shape0, z_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{7, 8, 9};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(maxpool_test)
{
migraphx::program p;
@@ -3714,6 +4017,34 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(min_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto curr_min = mm->add_instruction(migraphx::make_op("min"), x, y);
mm->add_instruction(migraphx::make_op("min"), curr_min, z);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{1, 4, 3};
std::vector<float> y_data{2, 8, 6};
std::vector<float> z_data{7, 5, 9};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
params0["z"] = migraphx::argument(input_fixed_shape0, z_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 4, 3};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fmod_test)
{
migraphx::program p;
@@ -3732,6 +4063,34 @@ TEST_CASE(fmod_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fmod_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto curr_mod = mm->add_instruction(migraphx::make_op("fmod"), x, y);
mm->add_instruction(migraphx::make_op("fmod"), curr_mod, z);
p.compile(migraphx::ref::target{});
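// fmod keeps the sign of the dividend (like std::fmod): fmod(-7, 2) == -1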
std::vector<float> x_data{-7, 8, -3};
std::vector<float> y_data{2, 4, 6};
std::vector<float> z_data{7, 5, 9};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
params0["z"] = migraphx::argument(input_fixed_shape0, z_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1, 0, -3};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fmod_float_test)
{
migraphx::program p;
@@ -3769,6 +4128,34 @@ TEST_CASE(mod_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(mod_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto curr_mod = mm->add_instruction(migraphx::make_op("mod"), x, y);
mm->add_instruction(migraphx::make_op("mod"), curr_mod, z);
p.compile(migraphx::ref::target{});
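// mod takes the sign of the divisor, e.g. mod(-7, 3) == 2, matching the gold values below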
std::vector<float> x_data{-3, 8, -7};
std::vector<float> y_data{3, 3, 3};
std::vector<float> z_data{10, 2, 9};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
params0["z"] = migraphx::argument(input_fixed_shape0, z_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(mod_float_test)
{
migraphx::program p;
@@ -3810,6 +4197,100 @@ TEST_CASE(mul_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(mul_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("mul"), x, y);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(x_data.size());
std::transform(x_data.begin(),
x_data.end(),
y_data.begin(),
gold.begin(),
[](float n1, float n2) -> float { return n1 * n2; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}};
std::vector<int32_t> a_data{0, 0, 0, 0};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", l1->get_shape().lens()}}),
l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2);
EXPECT(output(0, 1) == -3);
EXPECT(output(1, 0) == -2);
EXPECT(output(1, 1) == -3);
}
TEST_CASE(multibroadcast_2in_static_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}};
std::vector<int32_t> a_data{0, 0, 0, 0};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
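// two-input form: output dimensions are computed from both input shapes instead of an "out_lens" attribute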
mm->add_instruction(migraphx::make_op("multibroadcast"), l2, l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2);
EXPECT(output(0, 1) == -3);
EXPECT(output(1, 0) == -2);
EXPECT(output(1, 1) == -3);
}
TEST_CASE(multibroadcast_2in_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 4, 0}, {2, 2, 0}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
auto l1 = mm->add_parameter("a", a_shape);
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction(migraphx::make_op("multibroadcast"), l2, l1);
p.compile(migraphx::ref::target{});
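// the dynamic parameter "a" is bound to a concrete 2x2 argument at eval time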
std::vector<int32_t> a_data{0, 0, 0, 0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::int32_type, {2, 2}};
params0["a"] = migraphx::argument(input_fixed_shape0, a_data.data());
auto result = p.eval(params0).back();
auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2);
EXPECT(output(0, 1) == -3);
EXPECT(output(1, 0) == -2);
EXPECT(output(1, 1) == -3);
}
TEST_CASE(multinomial_test)
{
migraphx::program p;
@@ -4389,7 +4870,31 @@ TEST_CASE(pow_test)
mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto b = mm->add_parameter("b", s);
auto e = mm->add_parameter("e", s);
mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::ref::target{});
std::vector<float> data = {1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["b"] = migraphx::argument(input_fixed_shape0, data.data());
params0["e"] = migraphx::argument(input_fixed_shape0, data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
@@ -4778,6 +5283,30 @@ TEST_CASE(prelu_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(prelu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto slope = mm->add_parameter("slope", s);
mm->add_instruction(migraphx::make_op("prelu"), x, slope);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 2};
std::vector<float> slope_data{2, 1, 2};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["slope"] = migraphx::argument(input_fixed_shape0, slope_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
@@ -6449,6 +6978,30 @@ TEST_CASE(sqdiff_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sqdiff_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("sqdiff"), x, y);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
@@ -6584,6 +7137,30 @@ TEST_CASE(sub_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sub_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("sub"), x, y);
p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(tan_test)
{
migraphx::program p;
...
@@ -38,6 +38,27 @@ TEST_CASE(test_shape_default)
EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0);
}
TEST_CASE(test_dyn_4arg_constructor)
{
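// four-argument constructor: element type plus vectors of mins, maxes, and opts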
migraphx::shape s{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
}
TEST_CASE(test_shape_assign)
{
migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
@@ -185,6 +206,31 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_ndim_static)
{
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4, 4}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, {2, 4, 4, 1, 3}};
EXPECT(s2.ndim() == 5);
}
TEST_CASE(test_shape_ndim_dyn)
{
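// ndim() should count dynamic dimensions the same way it counts static lens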
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s2.ndim() == 5);
}
TEST_CASE(test_shape_non_packed_single_dim)
{
migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}};
@@ -212,6 +258,21 @@ TEST_CASE(test_shape_transposed2)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_static_to_dynamic)
{
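// to_dynamic() maps each static length d to the dynamic_dimension {d, d, 0}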
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_overlap)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
...