Commit 940220c8 authored by charlie's avatar charlie
Browse files

Progress?

parent 8b25fd3e
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis // In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving // of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5) // output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1) std::vector<std::size_t> s1)
{ {
...@@ -64,6 +66,79 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, ...@@ -64,6 +66,79 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return out_lens; return out_lens;
} }
// Handling opt dyn_dims calculation
std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
if(s0 == s1)
return s0;
if(s0.size() > s1.size())
s0.swap(s1);
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(a == 1 or b == 1)
{
return std::max(a, b);
}
else
{
// if not matching nor 1, set to 0
return static_cast<std::size_t>(0);
}
});
return out_lens;
}
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
if(s0 == s1)
return s0.dyn_dims();
if(not s0.dynamic() or not s1_dynamic())
{
// mixed fixed and dynamic
if(s0.dynamic())
s0.swap(s1);
}
else
{
// both dynamic
if(s0.dyn_dims().size() > s1.dyn_dims().size())
std::swap(s0, s1);
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
auto offset = s1.size() - s0.size();
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(not contains(one_dyn_dims, a) and not contains(one_dyn_dims, b))
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0) + "} and {" +
migraphx::to_string_range(s1) + "} mismatch!");
}
else
{
return shape::dynamic_dimension{std::max(a.min, b.min),
std::max(a.max, b.max),
(a.opt != b.opt) ? 0 : a.opt};
}
});
return out_dims;
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes // Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{ {
......
...@@ -37,9 +37,8 @@ struct operation; ...@@ -37,9 +37,8 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1); std::vector<std::size_t> s1);
// This version doesn't allow s0.size() > s1.size() std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> broadcast_s0s1_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1);
std::vector<std::size_t> s1);
shape common_shape(const std::vector<shape>& shapes); shape common_shape(const std::vector<shape>& shapes);
......
...@@ -105,7 +105,7 @@ struct multibroadcast ...@@ -105,7 +105,7 @@ struct multibroadcast
{ {
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens()); auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens()); auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens()); auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
return {t, return {t,
std::move(bcast_min_lens), std::move(bcast_min_lens),
std::move(bcast_max_lens), std::move(bcast_max_lens),
......
...@@ -1124,6 +1124,53 @@ TEST_CASE(multibroadcast) ...@@ -1124,6 +1124,53 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(multibroadcast_2in)
{
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
// dynamic_dimensions must be the same or one is {1, 1, 0}
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {2, 4, 0}, {1, 1, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
migraphx::shape b_shape{migraphx::shape::float_type, {1, 6, 2}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {10, 3, 8}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{10, 10, 0}, {3, 4, 0}, {8, 8, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
// both inputs are fixed
}
TEST_CASE(multinomial) TEST_CASE(multinomial)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 5}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment