Commit 5af9aac0 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_batch_pass' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner

parents 7b2516e0 05e81ed3
...@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
if(not fs::exists(out_path)) if(not fs::exists(out_path))
MIGRAPHX_THROW("Output file missing: " + out); MIGRAPHX_THROW("Output file missing: " + out);
if(process)
out_path = process(out_path);
return read_buffer(out_path.string()); return read_buffer(out_path.string());
} }
......
...@@ -326,8 +326,7 @@ struct compiler ...@@ -326,8 +326,7 @@ struct compiler
loader l; loader l;
program_params parameters; program_params parameters;
compiler_target ct; compiler_target ct;
bool offload_copy = false; compile_options co;
bool fast_math = true;
precision quantize = precision::fp32; precision quantize = precision::fp32;
std::vector<std::string> fill0; std::vector<std::string> fill0;
...@@ -337,19 +336,26 @@ struct compiler ...@@ -337,19 +336,26 @@ struct compiler
l.parse(ap); l.parse(ap);
parameters.parse(ap); parameters.parse(ap);
ct.parse(ap); ct.parse(ap);
ap(offload_copy, ap(co.offload_copy,
{"--enable-offload-copy"}, {"--enable-offload-copy"},
ap.help("Enable implicit offload copying"), ap.help("Enable implicit offload copying"),
ap.set_value(true)); ap.set_value(true));
ap(fast_math, ap(co.fast_math,
{"--disable-fast-math"}, {"--disable-fast-math"},
ap.help("Disable fast math optimization"), ap.help("Disable fast math optimization"),
ap.set_value(false)); ap.set_value(false));
ap(co.exhaustive_tune,
{"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16)); ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8)); ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
} }
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); } auto params(const program& p)
{
return parameters.generate(p, ct.get_target(), co.offload_copy);
}
program compile() program compile()
{ {
...@@ -366,10 +372,7 @@ struct compiler ...@@ -366,10 +372,7 @@ struct compiler
{ {
quantize_int8(p, t, {params(p)}); quantize_int8(p, t, {params(p)});
} }
compile_options options; p.compile(t, co);
options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options);
l.save(p); l.save(p);
return p; return p;
} }
...@@ -402,60 +405,41 @@ struct params : command<params> ...@@ -402,60 +405,41 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
loader l; compiler c;
program_params parameters;
compiler_target ct;
double tolerance = 80; double tolerance = 80;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); c.parse(ap);
parameters.parse(ap);
ct.parse(ap);
ap(offload_copy,
{"--enable-offload-copy"},
ap.help("Enable implicit offload copying"),
ap.set_value(true));
ap(fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
ap.set_value(true)); ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true)); ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
} }
void run() void run()
{ {
auto p = l.load(); auto p = c.l.load();
l.save(p); c.l.save(p);
std::cout << p << std::endl; std::cout << p << std::endl;
compile_options options; auto t = c.ct.get_target();
options.offload_copy = offload_copy; auto m = c.parameters.generate(p, t, true);
options.fast_math = fast_math;
auto t = ct.get_target();
auto m = parameters.generate(p, t, true);
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, options, quantize, tolerance); verify_instructions(p, t, c.co, c.quantize, tolerance);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, options, quantize, m, tolerance); verify_reduced_program(p, t, c.co, c.quantize, m, tolerance);
} }
else else
{ {
verify_program(l.file, p, t, options, quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance);
} }
} }
}; };
......
...@@ -108,8 +108,6 @@ target get_target(bool gpu) ...@@ -108,8 +108,6 @@ target get_target(bool gpu)
return make_target("cpu"); return make_target("cpu");
} }
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -37,7 +37,6 @@ parameter_map create_param_map(const program& p, const target& t, bool offload = ...@@ -37,7 +37,6 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -87,7 +87,7 @@ struct check_shapes ...@@ -87,7 +87,7 @@ struct check_shapes
} }
/*! /*!
* Check if the number of shape objects is equal to atleast one of the * Require the number of shape objects to equal to one of the
* given sizes. * given sizes.
* \param ns template parameter pack of sizes to check against * \param ns template parameter pack of sizes to check against
*/ */
...@@ -100,6 +100,23 @@ struct check_shapes ...@@ -100,6 +100,23 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Require the number of shape objects to equal at least a given amount. Use this
* method for ops that can take any number (variadic) of inputs.
* \param n min. number of shapes
*/
const check_shapes& has_at_least(std::size_t n) const
{
if(this->size() < n)
MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
to_string(n) + " but given " + std::to_string(size()));
return *this;
}
/*!
* Require all shapes to have the same number of elements.
* \param n number of
*/
const check_shapes& nelements(std::size_t n) const const check_shapes& nelements(std::size_t n) const
{ {
if(not this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
......
...@@ -32,8 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,8 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct compile_options struct compile_options
{ {
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true; bool fast_math = true;
bool exhaustive_tune = false;
tracer trace{}; tracer trace{};
}; };
......
...@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&) ...@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
{ {
return {}; return {};
} }
template <class T> template <class T>
void wait_for_context(T&, any_ptr) void wait_for_context(T&, any_ptr)
{ {
...@@ -302,7 +303,7 @@ struct context ...@@ -302,7 +303,7 @@ struct context
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(value) : private_detail_te_value(std::move(value))
{ {
} }
...@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x)
#endif #endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif #endif
......
...@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type; ...@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
namespace std { namespace std {
template <class T> template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T> struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
{ {
}; };
template <class T> template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T> struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{ {
}; };
......
...@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept; ...@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
namespace std { namespace std {
template <> template <>
struct hash<migraphx::instruction_ref> struct hash<migraphx::instruction_ref> // NOLINT
{ {
using argument_type = migraphx::instruction_ref; using argument_type = migraphx::instruction_ref;
using result_type = std::size_t; using result_type = std::size_t;
...@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref> ...@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
}; };
template <> template <>
struct equal_to<migraphx::instruction_ref> struct equal_to<migraphx::instruction_ref> // NOLINT
{ {
using argument_type = migraphx::instruction_ref; using argument_type = migraphx::instruction_ref;
using result_type = bool; using result_type = bool;
......
...@@ -36,22 +36,46 @@ template <class F> ...@@ -36,22 +36,46 @@ template <class F>
struct layernorm_matcher struct layernorm_matcher
{ {
F f; F f;
auto last_axis() const
{
return make_basic_pred_matcher([](instruction_ref ins) {
auto v = ins->get_operator().to_value();
if(not v.contains("axes"))
return false;
auto axes = v["axes"].to_vector<std::size_t>();
if(axes.size() != 1)
return false;
return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
});
}
auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
auto x_minus_mean() const auto x_minus_mean() const
{ {
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean")))); return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
} }
auto variance() const auto variance() const
{ {
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))))); return reduce_mean()(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
} }
auto layernorm_onnx() const auto sqrt_add_eps(const std::string& name) const
{ {
return f("div")(arg(0)(x_minus_mean()), auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
}
arg(1)(skip_broadcasts(f("sqrt")(arg(0)( auto layernorm_onnx() const
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")))))))); {
auto div_sqrt = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt));
} }
auto matcher() const { return layernorm_onnx(); } auto matcher() const { return layernorm_onnx(); }
......
...@@ -39,7 +39,7 @@ struct memory_coloring ...@@ -39,7 +39,7 @@ struct memory_coloring
{ {
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory_coloring"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
...@@ -54,6 +54,10 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins ...@@ -54,6 +54,10 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/ */
struct module struct module
{ {
// used by replace_allocate pass
// allocate memory in this module rather than using output parmaeters
bool use_local_alloc = false;
module(const std::string& name = ""); module(const std::string& name = "");
// move constructor // move constructor
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <array> #include <array>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -73,49 +74,87 @@ struct concat ...@@ -73,49 +74,87 @@ struct concat
} }
return offsets; return offsets;
} }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) // inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
// be at least 1.
check_shapes{inputs, *this, true}.same_ndims().same_type();
if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
{ {
MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0"); // Static input shapes
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
{
if(ll != axis)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[ll] == first_shape_lens[ll];
}))
{
MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
std::to_string(ll));
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens = first_shape_lens;
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
else if(std::all_of(
const auto& first_shape_lens = inputs.front().lens(); inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis) // Dynamic input shapes
for(std::size_t index = 0; index < inputs[0].ndim(); index++)
{ {
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) { if(index != axis)
return s.lens()[l] == first_shape_lens[l];
}))
{ {
MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match"); if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
}))
MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
std::to_string(index));
} }
} }
std::size_t new_min = 0;
std::size_t new_max = 0;
for(const auto& input : inputs)
{
auto ddim = input.dyn_dims()[axis];
new_min += ddim.min;
new_max += ddim.max;
}
auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0};
return {inputs[0].type(), new_dims};
} }
std::size_t new_dim_axis = 0; else
for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
new_dim_axis += lens[axis];
} }
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(dyn_out.computed_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape = auto slice_shape = shape{dyn_out.computed_shape.type(),
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; input.get_shape().lens(),
auto slice = make_view(slice_shape, output.data() + coffsets[l]); dyn_out.computed_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
std::copy(input.begin(), input.end(), slice.begin()); std::copy(input.begin(), input.end(), slice.begin());
}); });
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <array> #include <array>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -61,35 +62,59 @@ struct gather ...@@ -61,35 +62,59 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this, true}.has(2);
auto lens = inputs[0].lens(); shape data = inputs[0];
auto type = inputs[0].type(); shape indices = inputs[1];
lens.erase(lens.begin() + axis); auto type = data.type();
if(not inputs[1].scalar()) // If index_dims is dynamic, convert the data to dynamic too.
if(indices.dynamic())
{ {
auto ind_lens = inputs[1].lens(); data = data.to_dynamic();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
} }
if(data.dynamic())
// for scalar output
if(lens.empty())
{ {
return {type}; auto dims = data.dyn_dims();
dims.erase(dims.begin() + axis);
if(not indices.scalar())
{
auto index_dims = indices.to_dynamic().dyn_dims();
dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
}
return {type, dims};
} }
else
{
// Both data and indices are static. indices may be scalar
auto lens = data.lens();
lens.erase(lens.begin() + axis);
return {type, lens}; if(not indices.scalar())
{
auto ind_lens = indices.lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
if(lens.empty())
{
return {type};
}
return {type, lens};
}
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens(); auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis]; std::size_t axis_dim_size = lens[axis];
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
if(output_shape.scalar()) if(dyn_out.computed_shape.scalar())
{ {
auto in_index = indices.front(); auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP #define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
...@@ -47,33 +48,103 @@ struct gathernd ...@@ -47,33 +48,103 @@ struct gathernd
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this, true}.has(2);
auto r = inputs.front().lens().size(); auto i_shape = inputs.back();
auto q = inputs.back().lens().size(); auto data_shape = inputs.front();
auto k = inputs.back().lens().back(); auto r = data_shape.ndim();
auto q = i_shape.ndim();
size_t k;
if(i_shape.dynamic())
{
// the rank of the output is a function of k, so it must be fixed.
if(not i_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = i_shape.dyn_dims().back().min;
}
else
k = i_shape.lens().back();
// Begin input validation checks.
int output_ndim = int(q) + r - k - batch_dims - 1;
if(k > r - batch_dims) if(k > r - batch_dims)
{ {
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) + MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " + " cannot be used to access data of rank " +
std::to_string(r - batch_dims)); std::to_string(r - batch_dims));
} }
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1; if(batch_dims >= q or batch_dims >= r)
std::vector<std::size_t> output_lens(output_lens_size); {
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin()); MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
if(k < r - batch_dims) std::to_string(batch_dims));
}
if(output_ndim < 0)
{
MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
std::to_string(k));
}
if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
{
auto indices_lens_iter = i_shape.lens().begin();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape{data_shape.type(), {1}};
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<std::size_t> output_lens(output_ndim);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_lens = data_shape.lens();
std::copy(data_lens.begin() + batch_dims + k,
data_lens.end(),
output_lens.begin() + q - 1);
}
shape output_shape{data_shape.type(), output_lens};
return output_shape;
}
else
{ {
auto data_lens = inputs.front().lens(); // If one or both inputs are dynamic shapes, the output is dynamic.
std::copy( // Make both inputs dynamic to simplify computations.
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1); data_shape = data_shape.to_dynamic();
i_shape = i_shape.to_dynamic();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim);
std::copy(i_shape.dyn_dims().begin(),
i_shape.dyn_dims().begin() + q - 1,
output_dims.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_dims = data_shape.dyn_dims();
std::copy(data_dims.begin() + batch_dims + k,
data_dims.begin() + r,
output_dims.begin() + q - 1);
}
shape output_shape(data_shape.type(), output_dims);
return output_shape;
} }
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape(); auto indices_shape = indices.get_shape();
......
...@@ -31,18 +31,30 @@ namespace migraphx { ...@@ -31,18 +31,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// different attributes /**
// 1) use_input(default)/use_output * `normalize_attribute` settings:
// 2) use_rank(default)/use_len * Note that default options are not included as enums.
// 3) clip_min(default)/not_clip_min * 1. `use_input` (default) vs. `use_output`:
// 3.1) include_min(default)/exclude_min * Affects the rank of the attribute.
// 4) clip_max(default)/not_clip_max * `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
// 4.1) exclude_max(default)/include_max * 2. use_rank (default) vs use_len:
// 5) normalize padding * `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* Include or exclude the minimum value/index for range checking and clipping.
* 5. `clip_max` vs. `not_clip_max` (default):
* Clip values greater than the maximum or not.
* 6. `include_max` vs. `exclude_max` (default):
* Include or exclude the maximum value/index for range checking and clipping.
* 7. `normalize_padding`:
* To normalize the padding to `2*(pad ndim)` dimensions.
*/
enum class normalize_attribute enum class normalize_attribute
{ {
use_len,
use_output, use_output,
use_len,
clip_max, clip_max,
clip_min, clip_min,
include_max, include_max,
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -60,6 +61,7 @@ struct reverse ...@@ -60,6 +61,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1);
return inputs[0].with_lens(inputs[0].lens()); return inputs[0].with_lens(inputs[0].lens());
} }
......
...@@ -28,44 +28,89 @@ ...@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template <class Derived> template <class Derived>
struct scatternd_op : op_name<Derived> struct scatternd_op : op_name<Derived>
{ {
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this, true}.has(3);
auto r = inputs.front().lens().size(); auto data_shape = inputs.front();
auto q = inputs.at(1).lens().size(); auto index_shape = inputs.at(1);
auto k = inputs.at(1).lens().back(); auto upd_shape = inputs.back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens(); auto r = data_shape.ndim();
auto data_lens = inputs.front().lens(); auto q = index_shape.ndim();
size_t k;
if(index_shape.dynamic())
{
// the rank of the output is a function of k, so k must be fixed.
if(not index_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = index_shape.dyn_dims().back().min;
}
else
k = index_shape.lens().back();
// Checks on the sizes of input tensors
if(q + r != upd_shape.ndim() + k + 1)
MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) + " + " +
std::to_string(r) + " - " + std::to_string(k) +
" - 1 != " + std::to_string(upd_shape.ndim()));
if(k > r) if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) + MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r)); " is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1))) // Convert all static shape dimensions to dynamic so they can be compared.
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] " // It's possible for some of the 3 inputs to be dynamic shapes and some static,
"++ data.lens[k:r-1]"); // but any dynamic dimension that's compared to a static dimension must be fixed.
auto s = inputs.front(); auto ind_dims = index_shape.to_dynamic().dyn_dims();
if(s.broadcasted()) auto upd_dims = upd_shape.to_dynamic().dyn_dims();
auto data_dims = data_shape.to_dynamic().dyn_dims();
// Check that corresponding portions of tensor shapes match.
if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
"indices and data.");
if(data_shape.dynamic())
return data_shape;
else if(data_shape.broadcasted())
{ {
return {s.type(), s.lens()}; return {data_shape.type(), data_shape.lens()};
} }
else else
{ {
return s.with_lens(s.lens()); return data_shape.with_lens(data_shape.lens());
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto& self = static_cast<const Derived&>(*this); auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived> ...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto updates_std = shape{updates_shape.type(), updates_shape.lens()}; auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape(); auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size(); auto q = indices_shape.ndim();
auto r = output_shape.lens().size(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived> ...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std::copy(index_start, index_end, out_idx.begin()); std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); });
}); });
}); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct select_module
{
shape output_dyn_shapes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_dyn_shapes, "output_dyn_shapes"));
}
std::string name() const { return "select_module"; }
shape compute_shape(const std::vector<shape>&, const std::vector<module_ref>&) const
{
return shape{output_dyn_shapes};
}
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& submodule_list,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
// find submodule with input parameter shapes exactly the same as the input arguments
// assuming arguments are in the same order as the input parameters
auto module_iter =
std::find_if(submodule_list.cbegin(), submodule_list.cend(), [&](module_ref mr) {
auto param_names = mr->get_parameter_names();
assert(param_names.size() <= args.size());
return std::equal(param_names.cbegin(),
param_names.cend(),
args.cbegin(),
[&](auto p_name, auto a) {
return a.get_shape() == mr->get_parameter_shape(p_name);
});
});
if(module_iter == submodule_list.end())
{
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
}
auto* module_to_run = *module_iter;
std::unordered_map<std::string, argument> params;
// add input parameters
auto param_names = module_to_run->get_parameter_names();
assert(param_names.size() <= args.size());
std::transform(param_names.begin(),
param_names.end(),
args.begin(),
std::inserter(params, params.end()),
[](auto&& name, auto&& a) { return std::make_pair(name, a); });
auto results = run(module_to_run, params);
return argument{results};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -46,6 +47,10 @@ struct slice ...@@ -46,6 +47,10 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize = value::object{};
...@@ -65,14 +70,6 @@ struct slice ...@@ -65,14 +70,6 @@ struct slice
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const auto compute_offset(const shape& s) const
{ {
const std::vector<std::size_t>& lens = s.lens(); const std::vector<std::size_t>& lens = s.lens();
...@@ -83,14 +80,14 @@ struct slice ...@@ -83,14 +80,14 @@ struct slice
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis]; offset += starts[i] * strides[axis];
} }
} }
else else
{ {
for(std::size_t axis = 0; axis < lens.size(); axis++) for(std::size_t axis = 0; axis < lens.size(); axis++)
{ {
offset += fix_index(lens, axis, starts[axis]) * strides[axis]; offset += starts[axis] * strides[axis];
} }
} }
return offset; return offset;
...@@ -98,37 +95,81 @@ struct slice ...@@ -98,37 +95,81 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; check_shapes{inputs, *this, true}.has(1);
auto t = input_shape.type(); auto input_shape = inputs[0];
const auto& old_lens = input_shape.lens(); auto t = input_shape.type();
const auto& old_strides = input_shape.strides();
if(std::any_of( // TODO: When support for dynamic shapes is added to normalize_attributes,
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); })) // remove this restriction.
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range"); MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
} }
if(starts.size() != axes.size() or axes.size() != ends.size()) // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides;
if(input_shape.dynamic())
{
old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
}
else
{ {
MIGRAPHX_THROW("SLICE: inconsistent sizes"); old_lens = input_shape.lens();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides = input_shape.strides();
} }
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
new_lens[axis] = size_t sliced_length = ends[i] - starts[i];
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]); // A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens[axis] = std::min(new_lens[axis], sliced_length);
if(input_shape.dynamic())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
}
}
if(input_shape.dynamic())
{
return shape{t, new_mins, new_lens, new_opts};
}
else
{
return shape{t, new_lens, old_strides};
} }
return shape{t, new_lens, old_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }}; auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment