Commit b72ad090 authored by charlie's avatar charlie
Browse files

initial

parent 57f734a5
...@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, ...@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
}); });
return out_lens; return out_lens;
} }
std::vector<shape::dynamic_dimension>
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1) compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
std::vector<shape::dynamic_dimension> dds1)
{ {
// change both shapes to dynamic_dimension representation if(dds0.size() > dds1.size())
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
if(s0.ndim() > s1.ndim())
{ {
std::swap(s0, s1); std::swap(dds0, dds1);
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = dds1.size() - dds0.size();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(dds1);
std::transform(s0.dyn_dims().cbegin(), // If one within the range of the other
s0.dyn_dims().cend(), auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
s1.dyn_dims().cbegin() + offset, return (x.min >= y.min and x.max <= y.max);
};
std::transform(dds0.cbegin(),
dds0.cend(),
dds1.cbegin() + offset,
out_dims.begin() + offset, out_dims.begin() + offset,
[&](auto a, auto b) { [&](auto a, auto b) {
if(a == b or b == 1) if(a == b or b == 1)
...@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{ {
return b; return b;
} }
else if(dd_within_range(a, b))
{
return a;
}
else if(dd_within_range(b, a))
{
return b;
}
else else
{ {
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" + MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" + migraphx::to_string_range(dds0) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!"); migraphx::to_string_range(dds1) + "} mismatch!");
} }
}); });
return out_dims; return out_dims;
} }
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
// change both shapes to dynamic_dimension representation
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
return compute_broadcasted_dyn_dims(s0.dyn_dims(), s1.dyn_dims());
}
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes) std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{ {
auto ret_shape = shapes.at(0); auto ret_shape = shapes.at(0);
...@@ -151,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> ...@@ -151,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto c_dyn_dims = compute_common_dyn_dims(input_shapes); auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
auto s0 = inputs[0]->get_shape(); auto s0 = inputs[0]->get_shape();
if(not s0.dynamic() or s0.dyn_dims() != c_dyn_dims) // changed to always add the multibroadcast to handle the cases from split_single_dyn_dim
{ inputs[0] = m.insert_instruction(
inputs[0] = m.insert_instruction( ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
}
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the // uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime // full set of input shapes at runtime
auto s = input->get_shape(); auto s = input->get_shape();
if(not s.dynamic() or s.dyn_dims() != c_dyn_dims) return m.insert_instruction(
{ ins,
return m.insert_instruction( make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
ins, input,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs[0]);
input,
inputs[0]);
}
return input;
}); });
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
......
...@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT ...@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1); std::vector<std::size_t> s1);
MIGRAPHX_EXPORT
std::vector<shape::dynamic_dimension>
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
std::vector<shape::dynamic_dimension> dds1);
MIGRAPHX_EXPORT MIGRAPHX_EXPORT
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1); std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
......
...@@ -34,6 +34,9 @@ namespace migraphx { ...@@ -34,6 +34,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Matrix multiplication of two tensors.
*/
struct dot struct dot
{ {
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
...@@ -50,25 +53,36 @@ struct dot ...@@ -50,25 +53,36 @@ struct dot
} }
if(a.dynamic() or b.dynamic()) if(a.dynamic() or b.dynamic())
{ {
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto s0 = a.to_dynamic(); auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic(); auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2, if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(), s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2, s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend())) s1.dyn_dims().rend(),
[&](auto x, auto y) {
return (dd_within_range(x, y) or dd_within_range(y, x));
}))
{ {
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" + MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {" +
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
std::size_t dim_0 = s0.ndim() - 2; std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1; std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0]) auto x = s0.dyn_dims()[dim_1];
auto y = s1.dyn_dims()[dim_0];
if(not dd_within_range(x, y) and not dd_within_range(y, x))
{ {
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" + MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0.
auto out_dyn_dims = s0.dyn_dims(); auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1]; out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims}; return {t, out_dyn_dims};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Broadcast dimensions between two tensors for the `dot` operator.
* Essentially broadcasts between two shapes for dimensions other than the last two.
* This operator is only needed if one of the shapes are dynamic.
* Example:
* a = shape[{1, 4}, 3, 248, 248]
* b = shape[248, 365]
* dot_broadcast(a, b) => shape[{1, 4}, 3, 248, 248] (no change)
* dot_broadcast(b, a) => shape[{1, 4}, 3, 248, 365]
*/
struct dot_broadcast
{
std::string name() const { return "dot_broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(2);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0.dynamic() or s1.dynamic())
{
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
auto dds0_it = s0.dyn_dims().end() - 2;
auto dds1_it = s1.dyn_dims().end() - 2;
std::vector<shape::dynamic_dimension> sliced_dds0{s0.dyn_dims().begin(), dds0_it};
std::vector<shape::dynamic_dimension> sliced_dds1{s1.dyn_dims().begin(), dds1_it};
auto output_dyn_dims = compute_broadcasted_dyn_dims(sliced_dds0, sliced_dds1);
output_dyn_dims.insert(output_dyn_dims.end(), dds0_it, s0.dyn_dims().end());
return {s0.type(), output_dyn_dims};
}
else
{
auto l0_it = s0.lens().begin() + s0.ndim() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0.lens().begin(), l0_it);
auto l1_it = s1.lens().begin() + s1.ndim() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
return {s0.type(), output_lens};
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul>
auto s0_dds = a0->get_shape().to_dynamic().dyn_dims(); auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
auto s1_dds = a1->get_shape().to_dynamic().dyn_dims(); auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
// TODO: handling this case requires a new multibroadcast mode
if(not std::equal( if(not std::equal(
s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend())) s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
{ {
MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported"); auto broadcasted_a0 = info.add_instruction(make_op("dot_broadcast"), a0, a1);
auto broadcasted_a1 = info.add_instruction(make_op("dot_broadcast"), a1, a0);
dot_res =
info.add_instruction(make_op(opd.op_name), broadcasted_a0, broadcasted_a1);
}
else
{
dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
} }
dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
} }
else else
{ {
......
...@@ -318,6 +318,39 @@ struct find_const_alloc_fill ...@@ -318,6 +318,39 @@ struct find_const_alloc_fill
} }
}; };
/**
* Simplify dot_broadcast instructions with two static shaped arguments
* From:
* dot_broadcast(static_shape_arg, static_shape_arg)
* To:
* multibroadcast(static_shape_arg); output_lens = static_dot_broadcasted_shape
*/
struct find_static_dot_broadcast
{
auto matcher() const
{
return match::name("dot_broadcast")(match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto dot_broadcast_ins = mr.result;
auto inputs = dot_broadcast_ins->inputs();
auto s0 = inputs.at(0)->get_shape();
auto s1 = inputs.at(1)->get_shape();
auto l0_it = s0.lens().begin() + s0.ndim() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0.lens().begin(), l0_it);
auto l1_it = s1.lens().begin() + s1.ndim() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
m.replace_instruction(dot_broadcast_ins,
make_op("multibroadcast", {{"out_lens", output_lens}}),
inputs.at(0));
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches(m, match::find_matches(m,
...@@ -327,7 +360,8 @@ void simplify_dyn_ops::apply(module& m) const ...@@ -327,7 +360,8 @@ void simplify_dyn_ops::apply(module& m) const
find_const_2in_slice{}, find_const_2in_slice{},
find_const_3in_slice{}, find_const_3in_slice{},
find_const_4in_slice{}, find_const_4in_slice{},
find_const_alloc_fill{}); find_const_alloc_fill{},
find_static_dot_broadcast{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment