"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6ae2f087eb04b45acaff42fc20b30a6b08f24cca"
Commit b162c4ec authored by charlie's avatar charlie
Browse files

More progress

parent 412c298e
...@@ -31,6 +31,22 @@ ...@@ -31,6 +31,22 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
auto compute_broadcasting = [](std::vector<std::size_t> s0, std::vector<std::size_t> s1) {
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
};
// Example: // Example:
// s0 = (3,2,4,5) and s1 = (2,1,1) // s0 = (3,2,4,5) and s1 = (2,1,1)
// //
...@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, ...@@ -50,20 +66,17 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return s0; return s0;
if(s0.size() > s1.size()) if(s0.size() > s1.size())
s0.swap(s1); s0.swap(s1);
return compute_broadcasting(s0, s1);
}
std::vector<std::size_t> out_lens(s1); std::vector<std::size_t> broadcast_s0s1_lens(std::vector<std::size_t> s0,
auto offset = s1.size() - s0.size(); std::vector<std::size_t> s1)
std::transform( {
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) { if(s0 == s1)
if(a != b and a != 1 and b != 1) return s0;
{ if(s0.size() > s1.size())
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" + MIGRAPHX_THROW("BROADCAST_SHAPE_LENS: s0 size > s1 size and swap not allowed");
to_string_range(s1) + "} mismatch!"); return compute_broadcasting(s0, s1);
}
return std::max(a, b);
});
return out_lens;
} }
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
...@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m, ...@@ -114,6 +127,7 @@ instruction_ref insert_common_op(module& m,
const operation& op, const operation& op,
std::vector<instruction_ref> inputs) std::vector<instruction_ref> inputs)
{ {
// TODO update this to handle dynamic shapes
auto common = common_shape(to_shapes(inputs)); auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens()) if(input->get_shape().lens() != common.lens())
......
...@@ -36,6 +36,11 @@ struct operation; ...@@ -36,6 +36,11 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1); std::vector<std::size_t> s1);
// This version doesn't allow s0.size() > s1.size()
std::vector<std::size_t> broadcast_s0s1_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
shape common_shape(const std::vector<shape>& shapes); shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m, instruction_ref insert_common_op(module& m,
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -56,63 +57,73 @@ struct broadcast ...@@ -56,63 +57,73 @@ struct broadcast
{ {
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto t = s0.type(); auto t = s0.type();
if (inputs.size() == 1) if(inputs.size() == 1)
{ {
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative // the broadcast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
else
{
if(s0.dynamic() and s1.dynamic())
{
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens());
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(size_t i = 0; i < bcast_max_lens.size(); ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return {t, output_dyn_dims};
}
else if(not s0.dynamic() and not s1.dynamic())
{ {
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens()); MIGRAPHX_THROW("BROADCAST : axis is out of range");
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); }
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
return {t, std::move(bcast_lens), std::move(bcast_strides)}; if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
else
{
auto s1 = inputs.at(1);
if(axis >= s1.max_lens().size())
{
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range of s1");
}
if(s1.max_lens().size() - axis < s0.max_lens().size())
{
MIGRAPHX_THROW("BROADCAST_2in: (s1 rank - axis) is less than s0 rank");
}
if(s0.dynamic() or s1.dynamic())
{
auto bcast_max_lens = broadcast_s0s1_lens(s0.max_lens(), s1.max_lens());
auto bcast_min_lens = broadcast_s0s1_lens(s0.min_lens(), s1.min_lens());
auto bcast_opt_lens = broadcast_s0s1_lens(s0.opt_lens(), s1.opt_lens());
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(size_t i = 0; i < bcast_max_lens.size(); ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
bcast_max_lens[i], bcast_min_lens[i], bcast_opt_lens[i]});
}
return {t, std::move(output_dyn_dims)};
} }
else else
{ {
MIGRAPHX_THROW( if(not std::equal(s0.lens().begin(), s0.lens().end(), s1.lens().begin() + axis))
"BROADCAST: s0 and s1 are not both dynamic or static"); {
MIGRAPHX_THROW("BROADCAST_2in: when broadcasting, succeeding sizes must match");
}
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
} }
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -50,27 +50,25 @@ struct multibroadcast ...@@ -50,27 +50,25 @@ struct multibroadcast
{ {
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has(1, 2);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
if(s0.lens().empty())
if(s0.lens().empty()) {
{ MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0"); }
}
auto make_bcast_strides = [&](std::size_t out_num_dims, std::size_t offset) {
auto make_bcast_strides = [&](std::size_t out_num_dims, std::size_t offset) std::vector<size_t> bcast_strides(out_num_dims, 0);
{ for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
std::vector<size_t> bcast_strides(out_num_dims, 0); {
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--) if(output_lens[i + offset] == s0.lens()[i])
{ {
if(output_lens[i + offset] == s0.lens()[i]) bcast_strides[i + offset] = s0.strides()[i];
{ }
bcast_strides[i + offset] = s0.strides()[i]; }
} return bcast_strides;
} };
return bcast_strides;
};
if(inputs.size() == 1) if(inputs.size() == 1)
{ {
...@@ -84,52 +82,42 @@ struct multibroadcast ...@@ -84,52 +82,42 @@ struct multibroadcast
{ {
if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1) if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
"MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) + "} cannot be broadcasted to {" + to_string_range(output_lens) +
"} cannot be broadcasted to {" + to_string_range(output_lens) + "}!"); "}!");
} }
} }
auto bcast_strides = make_bcast_strides(output_lens.size(), offset); auto bcast_strides = make_bcast_strides(output_lens.size(), offset);
return {t, output_lens, std::move(bcast_strides)}; return {t, output_lens, std::move(bcast_strides)};
} }
else else
{ {
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
// do we want this to work in both broadcast directions?
// s0 and s1 as shape inputs
// always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
if(s0.dynamic() and s1.dynamic()) if(s0.max_lens().size() > s1.max_lens().size())
{ {
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens()); MIGRAPHX_THROW("MULTIBROADCAST: s0 rank should <= s1 rank");
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens()); }
auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens()); if(s0.dynamic() or s1.dynamic())
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(size_t i = 0; i < bcast_max_lens.size(); ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return {t, std::move(output_dyn_dims)};
}
else if(not s0.dynamic() and not s1.dynamic())
{ {
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens()); auto bcast_max_lens = broadcast_s0s1_lens(s0.max_lens(), s1.max_lens());
auto offset = s1.lens().size() - s0.lens().size(); auto bcast_min_lens = broadcast_s0s1_lens(s0.min_lens(), s1.min_lens());
auto bcast_strides = make_bcast_strides(s1.lens().size(), offset); auto bcast_opt_lens = broadcast_s0s1_lens(s0.opt_lens(), s1.opt_lens());
return {t, std::move(bcast_lens), std::move(bcast_strides)};
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(size_t i = 0; i < bcast_max_lens.size(); ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
bcast_max_lens[i], bcast_min_lens[i], bcast_opt_lens[i]});
}
return {t, std::move(output_dyn_dims)};
} }
else else
{ {
MIGRAPHX_THROW( auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
"MULTIBROADCAST: s0 and s1 are not both dynamic or static"); auto offset = s1.lens().size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(s1.lens().size(), offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment