"examples/deepseek_mla/vscode:/vscode.git/clone" did not exist on "d684094bd1b13059d4b2d764abb3ebb5e1dcf5c0"
Commit c539a7b0 authored by charlie

Refactor into precomputing dyn output shape

also adding limitations on broadcasting dynamic shapes
parent 5fc6afe6
@@ -66,33 +66,50 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
     return out_lens;
 }
 
-// Handling opt dyn_dims calculation
-std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
-                                                      std::vector<std::size_t> s1)
+std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
 {
-    if(s0 == s1)
-        return s0;
-    if(s0.size() > s1.size())
-        s0.swap(s1);
-    std::vector<std::size_t> out_lens(s1);
-    auto offset = s1.size() - s0.size();
-    std::transform(
-        s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
-            if(a == b)
-            {
-                return a;
-            }
-            else if((a == 1 or b == 1) and a != 0 and b != 0)
-            {
-                return std::max(a, b);
-            }
-            else
-            {
-                // if not matching nor 1, set to 0
-                return static_cast<std::size_t>(0);
-            }
-        });
-    return out_lens;
+    if(s0.dynamic() or s1.dynamic())
+    {
+        // change both shapes to dynamic_dimension representation
+        if(not s0.dynamic())
+            s0 = s0.to_dynamic();
+        if(not s1.dynamic())
+            s1 = s1.to_dynamic();
+
+        if(s0.rank() > s1.rank())
+        {
+            std::swap(s0, s1);
+        }
+        auto offset = s1.rank() - s0.rank();
+        std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
+        std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
+        std::transform(
+            s0.dyn_dims().cbegin(),
+            s0.dyn_dims().cend(),
+            s1.dyn_dims().cbegin() + offset,
+            out_dims.begin() + offset,
+            [&](auto a, auto b) {
+                if(a == b)
+                {
+                    return a;
+                }
+                else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b))
+                {
+                    return shape::dynamic_dimension{
+                        std::max(a.min, b.min), std::max(a.max, b.max), std::max(a.opt, b.opt)};
+                }
+                else
+                {
+                    MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
+                                   migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
+                                   migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
+                }
+            });
+        return out_dims;
+    }
+    else
+    {
+        MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes");
+    }
 }
 
 // Compute the common (broadcasted) dimensions of a list of fixed shapes
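For reference, a minimal standalone sketch of the broadcasting rule that compute_broadcasted_dyn_dims implements above. It uses an illustrative DynDim struct and broadcast_dyn_dims function in place of migraphx::shape::dynamic_dimension and the MIGraphX helpers, so the names and details here are assumptions, not the library API:

// Illustrative only: broadcast two lists of dynamic dimensions, where each
// dimension is a (min, max, opt) range. DynDim stands in for
// migraphx::shape::dynamic_dimension.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

struct DynDim
{
    std::size_t min;
    std::size_t max;
    std::size_t opt;
    bool operator==(const DynDim& o) const
    {
        return min == o.min and max == o.max and opt == o.opt;
    }
    // matches the {1, 1, 0} / {1, 1, 1} cases checked via one_dyn_dims
    bool is_one() const { return min == 1 and max == 1; }
};

std::vector<DynDim> broadcast_dyn_dims(std::vector<DynDim> a, std::vector<DynDim> b)
{
    if(a.size() > b.size())
        std::swap(a, b); // align the lower-rank shape against the trailing dimensions
    auto offset = b.size() - a.size();
    std::vector<DynDim> out(b);
    std::transform(
        a.begin(), a.end(), b.begin() + offset, out.begin() + offset, [](DynDim x, DynDim y) {
            if(x == y)
                return x; // identical ranges pass through
            if(x.is_one() or y.is_one())
                return DynDim{
                    std::max(x.min, y.min), std::max(x.max, y.max), std::max(x.opt, y.opt)};
            throw std::runtime_error("dynamic dimensions mismatch"); // analogous MIGRAPHX_THROW case
        });
    return out;
}

int main()
{
    // {1..4} x {1} broadcasts against {1..4} x {3..6}: the size-1 dimension expands.
    std::vector<DynDim> s0{{1, 4, 2}, {1, 1, 1}};
    std::vector<DynDim> s1{{1, 4, 2}, {3, 6, 4}};
    auto out = broadcast_dyn_dims(s0, s1); // -> {{1, 4, 2}, {3, 6, 4}}
    return out.size() == 2 ? 0 : 1;
}

Equal ranges pass through, a {1, 1, *} range broadcasts to the larger one, and anything else is rejected, which is what the MIGRAPHX_THROW branch in the new helper enforces.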
@@ -149,24 +166,36 @@ instruction_ref insert_common_op(module& m,
     if(std::any_of(
            inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
     {
+        // currently only handles the binary case
+        if(inputs.size() != 2)
+        {
+            MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
+                           "inputs, only handle two inputs");
+        }
         auto c_type = compute_common_types(to_shapes(inputs));
-        // broadcast all inputs combinations
-        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto a_input) {
-            const auto& ori_input = a_input;
-            // multibroadcast this input between every other input
-            std::for_each(inputs.cbegin(), inputs.cend(), [&](auto b_input) {
-                if(b_input != ori_input)
-                {
-                    a_input =
-                        m.insert_instruction(ins, make_op("multibroadcast"), a_input, b_input);
-                }
-            });
-            if(a_input->get_shape().type() != c_type)
+        auto c_dyn_dims =
+            compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
+
+        // following should work for a static or dynamic shape
+        // TODO: compute_broadcasted_dyn_dims() is going to be called again in the multibroadcast
+        // compute_shape should figure out a way to get around recomputing that. Attribute in
+        // multibroadcast?
+        if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
+        {
+            inputs[0] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[0], inputs[1]);
+        }
+        if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
+        {
+            inputs[1] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[1], inputs[0]);
+        }
+        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
+            if(input->get_shape().type() != c_type)
             {
-                a_input = m.insert_instruction(
-                    ins, make_op("convert", {{"target_type", c_type}}), a_input);
+                input =
+                    m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
             }
-            return a_input;
+            return input;
         });
     }
     else
......
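The refactor above replaces the pairwise multibroadcast loop in insert_common_op with a single precomputed common shape, after which only an input that actually differs gets a multibroadcast or convert inserted. A rough standalone sketch of that control flow, with stand-in Tensor, broadcast_to, convert_to, and common_dims helpers rather than MIGraphX instructions (all names here are illustrative):

// Illustrative only: the shape of the refactored flow, not the MIGraphX API.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

struct Tensor
{
    std::vector<std::size_t> dims; // stand-in for shape::dyn_dims()
    std::string type;              // stand-in for shape::type()
};

// stand-ins for m.insert_instruction(ins, make_op("multibroadcast"/"convert"), ...)
Tensor broadcast_to(const Tensor& t, const std::vector<std::size_t>& dims) { return {dims, t.type}; }
Tensor convert_to(const Tensor& t, const std::string& type) { return {t.dims, type}; }

// simplified stand-in for compute_broadcasted_dyn_dims on fixed sizes
std::vector<std::size_t> common_dims(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() > b.size())
        std::swap(a, b);
    auto offset = b.size() - a.size();
    for(std::size_t i = 0; i < a.size(); ++i)
        b[i + offset] = std::max(a[i], b[i + offset]); // assumes any mismatch involves a 1
    return b;
}

int main()
{
    std::vector<Tensor> inputs{{{4, 1}, "half"}, {{4, 3}, "float"}};
    const std::string c_type = "float";                              // compute_common_types(...)
    const auto c_dims = common_dims(inputs[0].dims, inputs[1].dims); // computed once, up front

    for(auto& in : inputs)
    {
        if(in.dims != c_dims)
            in = broadcast_to(in, c_dims); // broadcast only the input that differs
        if(in.type != c_type)
            in = convert_to(in, c_type); // convert only the input whose type differs
    }
    return (inputs[0].dims == c_dims and inputs[1].type == c_type) ? 0 : 1;
}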
@@ -37,9 +37,6 @@ struct operation;
 std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                   std::vector<std::size_t> s1);
 
-std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
-                                                      std::vector<std::size_t> s1);
-
 shape common_shape(const std::vector<shape>& shapes);
 
 instruction_ref insert_common_op(module& m,
......
@@ -28,6 +28,7 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -63,7 +64,15 @@ struct binary : op_name<Derived>
         check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
-        if(s0 == s1 and s0.packed())
+        if(s0.dynamic() and s1.dynamic() and s0 == s1)
+        {
+            return s0;
+        }
+        else if(s0.dynamic() or s1.dynamic())
+        {
+            MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
+        }
+        else if(s0 == s1 and s0.packed())
         {
             return s0;
         }
@@ -81,9 +90,9 @@ struct binary : op_name<Derived>
         }
     }
 
-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
             std::transform(input1.begin(),
                            input1.end(),
......
@@ -104,10 +104,51 @@ struct multibroadcast
         if(s0.dynamic() and s1.dynamic())
         {
             // TODO handle both dynamic case
-            MIGRAPHX_THROW("MULTIBROADCAST_2IN: two dynamic shape inputs not handled.");
+            MIGRAPHX_THROW(
+                "MULTIBROADCAST_2IN: not handled; two dynamic shape inputs not handled");
         }
         else if(s0.dynamic() or s1.dynamic())
         {
+            // only handles the case when broadcasting static shape to dynamic shape
+            // all the dimensions in the static shape must match to a fixed dimension in the
+            // dynamic shape or be 1
+            // TODO: handling the other possibilities
+            if(s1.dynamic())
+            {
+                std::swap(s0, s1);
+            }
+            auto static_rank = s1.lens().size();
+            auto dyn_rank    = s0.max_lens().size();
+            if(static_rank > dyn_rank)
+            {
+                MIGRAPHX_THROW("MULTIBROADCAST_2IN: not handled; static shape has a higher "
+                               "rank than dynamic shape");
+            }
+            return s0;
+            auto offset = dyn_rank - static_rank;
+            std::vector<shape::dynamic_dimension> out_dims(s0.dyn_dims());
+            std::transform(s0.dyn_dims().begin(),
+                           s0.dyn_dims().end(),
+                           s1.lens().begin() + offset,
+                           out_lens.begin() + offset,
+                           [&](auto a, auto b) {
+                               if(a == b)
+                               {
+                                   return a;
+                               }
+                               else if((a == 1 or b == 1) and a != 0 and b != 0)
+                               {
+                                   return std::max(a, b);
+                               }
+                               else
+                               {
+                                   // if not matching nor 1, set to 0
+                                   return static_cast<std::size_t>(0);
+                               }
+                           });
+            /*
             auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
             auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
             auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
@@ -115,6 +156,7 @@ struct multibroadcast
                             std::move(bcast_min_lens),
                             std::move(bcast_max_lens),
                             std::move(bcast_opt_lens)};
+            */
         }
         else
         {
......
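The static-to-dynamic branch above states its restriction in a comment: every dimension of the static shape must be 1 or match a fixed dimension of the dynamic shape. A small standalone check of that stated rule, using an assumed DynDim stand-in rather than shape::dynamic_dimension (names are illustrative, not MIGraphX code):

// Illustrative only: the restriction described in the comment, checked on a
// stand-in DynDim type with just the min/max bounds needed here.
#include <cstddef>
#include <vector>

struct DynDim
{
    std::size_t min;
    std::size_t max;
    bool fixed() const { return min == max; } // a fixed dimension has min == max
};

bool static_broadcastable_to_dynamic(const std::vector<std::size_t>& static_lens,
                                     const std::vector<DynDim>& dyn_dims)
{
    // the static shape may not have a higher rank than the dynamic shape
    if(static_lens.size() > dyn_dims.size())
        return false;
    auto offset = dyn_dims.size() - static_lens.size();
    for(std::size_t i = 0; i < static_lens.size(); ++i)
    {
        const auto& d = dyn_dims[i + offset];
        if(static_lens[i] == 1)
            continue; // a dimension of 1 always broadcasts
        if(not(d.fixed() and d.min == static_lens[i]))
            return false; // otherwise it must equal a fixed dynamic dimension
    }
    return true;
}

int main()
{
    std::vector<DynDim> dyn{{1, 4}, {3, 3}, {5, 5}};
    bool ok  = static_broadcastable_to_dynamic({3, 5}, dyn); // 3 and 5 match fixed dims
    bool bad = static_broadcastable_to_dynamic({2, 5}, dyn); // 2 matches nothing fixed
    return (ok and not bad) ? 0 : 1;
}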