Commit 9280150b authored by charlie
Browse files

scratch work to get bert_uncased working with dynamic batch

parent c84b8195
......@@ -139,6 +139,7 @@ register_migraphx_ops(
dimensions_of
div
dot
dot_broadcast
elu
equal
erf
......
......@@ -51,21 +51,23 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
});
return out_lens;
}
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
std::vector<shape::dynamic_dimension>
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
std::vector<shape::dynamic_dimension> dds1)
{
// change both shapes to dynamic_dimension representation
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
if(s0.ndim() > s1.ndim())
if(dds0.size() > dds1.size())
{
std::swap(s0, s1);
std::swap(dds0, dds1);
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform(s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
auto offset = dds1.size() - dds0.size();
std::vector<shape::dynamic_dimension> out_dims(dds1);
// If one within the range of the other
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
std::transform(dds0.cbegin(),
dds0.cend(),
dds1.cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b or b == 1)
......@@ -76,16 +78,32 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return b;
}
else if(dd_within_range(a, b))
{
return a;
}
else if(dd_within_range(b, a))
{
return b;
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
migraphx::to_string_range(dds0) + "} and {" +
migraphx::to_string_range(dds1) + "} mismatch!");
}
});
return out_dims;
}
/**
 * Shape-based convenience overload.
 * Both inputs are first converted with shape::to_dynamic() so that static and
 * dynamic shapes are handled uniformly, then their dynamic dimensions are
 * forwarded to the vector-based overload.
 */
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
    const shape dyn_lhs = s0.to_dynamic();
    const shape dyn_rhs = s1.to_dynamic();
    return compute_broadcasted_dyn_dims(dyn_lhs.dyn_dims(), dyn_rhs.dyn_dims());
}
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{
auto ret_shape = shapes.at(0);
......
......@@ -58,6 +58,11 @@ MIGRAPHX_EXPORT
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
/**
 * Broadcast two lists of dynamic dimensions against each other.
 * The shorter list is aligned with the trailing dimensions of the longer one
 * and the result has as many entries as the longer input.
 * When the [min, max] range of one dimension lies within the range of the
 * other, the narrower dimension is used.
 * Throws (MIGRAPHX_THROW) if a pair of aligned dimensions cannot be reconciled.
 */
MIGRAPHX_EXPORT
std::vector<shape::dynamic_dimension>
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
std::vector<shape::dynamic_dimension> dds1);
/**
 * Shape-based overload: converts both shapes with shape::to_dynamic() and
 * forwards their dynamic dimensions to the vector-based overload.
 */
MIGRAPHX_EXPORT
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
......
......@@ -63,7 +63,12 @@ struct dot
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto x = s0.dyn_dims()[dim_1];
auto y = s1.dyn_dims()[dim_0];
if(not dd_within_range(x, y) and not dd_within_range(y, x))
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_DOT_BROADCAST_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * Broadcast dimensions between two tensors for the `dot` operator.
 * Broadcasts the leading (batch) dimensions of the two input shapes against
 * each other; the last two (matrix) dimensions are always taken from the
 * first input unchanged.
 * This operator is only needed if at least one of the shapes is dynamic.
 * Both inputs must have at least two dimensions.
 * Example:
 * a = shape[{1, 4}, 3, 248, 248]
 * b = shape[248, 365]
 * dot_broadcast(a, b) => shape[{1, 4}, 3, 248, 248] (no change)
 * dot_broadcast(b, a) => shape[{1, 4}, 3, 248, 365]
 */
struct dot_broadcast
{
    std::string name() const { return "dot_broadcast"; }

    /**
     * Compute the shape of the first input after its batch dimensions have
     * been broadcast against the batch dimensions of the second input.
     *
     * @param inputs exactly two shapes (static or dynamic).
     * @return a dynamic shape if either input is dynamic; otherwise a static
     *         shape whose strides are 0 along every broadcast axis so that it
     *         can view the first input's buffer without copying.
     * @throws via MIGRAPHX_THROW if either input has fewer than 2 dimensions,
     *         or from the broadcast helpers if the batch dimensions mismatch.
     */
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this, true}.has(2);
        const auto& s0 = inputs.at(0);
        const auto& s1 = inputs.at(1);
        // The slicing below strips the last two dimensions; without this guard
        // `end() - 2` would step before `begin()` for rank-0/rank-1 inputs.
        if(s0.ndim() < 2 or s1.ndim() < 2)
        {
            MIGRAPHX_THROW("DOT_BROADCAST: both inputs must have at least 2 dimensions");
        }
        if(s0.dynamic() or s1.dynamic())
        {
            // Work on dynamic representations of both shapes so a static
            // input can be mixed with a dynamic one.
            const auto dyn0  = s0.to_dynamic();
            const auto dyn1  = s1.to_dynamic();
            const auto& dds0 = dyn0.dyn_dims();
            const auto& dds1 = dyn1.dyn_dims();
            const std::vector<shape::dynamic_dimension> batch_dds0(dds0.begin(), dds0.end() - 2);
            const std::vector<shape::dynamic_dimension> batch_dds1(dds1.begin(), dds1.end() - 2);
            auto output_dyn_dims = compute_broadcasted_dyn_dims(batch_dds0, batch_dds1);
            // Re-attach the matrix dimensions of the first input.
            output_dyn_dims.insert(output_dyn_dims.end(), dds0.end() - 2, dds0.end());
            return {s0.type(), output_dyn_dims};
        }
        else
        {
            const auto& lens0 = s0.lens();
            const auto& lens1 = s1.lens();
            const std::vector<std::size_t> batch_lens0(lens0.begin(), lens0.end() - 2);
            const std::vector<std::size_t> batch_lens1(lens1.begin(), lens1.end() - 2);
            auto output_lens = compute_broadcasted_lens(batch_lens0, batch_lens1);
            // Re-attach the matrix dimensions of the first input.
            output_lens.insert(output_lens.end(), lens0.end() - 2, lens0.end());

            // compute() only re-views the input buffer, so the output shape
            // must describe that buffer: keep the input stride where the
            // length is unchanged and use stride 0 on every axis that was
            // added or expanded by broadcasting.
            const auto& strides0     = s0.strides();
            const std::size_t offset = output_lens.size() - lens0.size();
            std::vector<std::size_t> output_strides(output_lens.size(), 0);
            for(std::size_t i = 0; i < lens0.size(); ++i)
            {
                if(output_lens[i + offset] == lens0[i])
                {
                    output_strides[i + offset] = strides0[i];
                }
            }
            return {s0.type(), output_lens, output_strides};
        }
    }

    /**
     * Return the first argument viewed through the computed output shape.
     * No data is copied; only the shape attached to the buffer changes.
     * NOTE(review): presumably dyn_out.computed_shape is produced by calling
     * compute_shape() with the runtime (static) argument shapes - confirm.
     */
    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        return args[0].reshape(dyn_out.computed_shape);
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -69,7 +69,7 @@ struct reshape
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
if(num_not_fixed == 1)
{
MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
}
......@@ -110,6 +110,12 @@ struct reshape
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
/*
std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dds(dims.size(),
shape::dynamic_dimension{0, max_val});
return {s0.type(), dds};
*/
}
template <class Iterator>
......
......@@ -71,14 +71,18 @@ struct parse_matmul : op_parser<parse_matmul>
auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
// TODO: handling this case requires a new multibroadcast mode
if(not std::equal(
s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
{
MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
auto broadcasted_a0 = info.add_instruction(make_op("dot_broadcast"), a0, a1);
auto broadcasted_a1 = info.add_instruction(make_op("dot_broadcast"), a1, a0);
dot_res =
info.add_instruction(make_op(opd.op_name), broadcasted_a0, broadcasted_a1);
}
else
{
dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
}
dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment