Commit 5af9aac0 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_batch_pass' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner

parents 7b2516e0 05e81ed3
......@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
if(not fs::exists(out_path))
MIGRAPHX_THROW("Output file missing: " + out);
if(process)
out_path = process(out_path);
return read_buffer(out_path.string());
}
......
......@@ -326,8 +326,7 @@ struct compiler
loader l;
program_params parameters;
compiler_target ct;
bool offload_copy = false;
bool fast_math = true;
compile_options co;
precision quantize = precision::fp32;
std::vector<std::string> fill0;
......@@ -337,19 +336,26 @@ struct compiler
l.parse(ap);
parameters.parse(ap);
ct.parse(ap);
ap(offload_copy,
ap(co.offload_copy,
{"--enable-offload-copy"},
ap.help("Enable implicit offload copying"),
ap.set_value(true));
ap(fast_math,
ap(co.fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(co.exhaustive_tune,
{"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
}
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); }
auto params(const program& p)
{
return parameters.generate(p, ct.get_target(), co.offload_copy);
}
program compile()
{
......@@ -366,10 +372,7 @@ struct compiler
{
quantize_int8(p, t, {params(p)});
}
compile_options options;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options);
p.compile(t, co);
l.save(p);
return p;
}
......@@ -402,60 +405,41 @@ struct params : command<params>
struct verify : command<verify>
{
loader l;
program_params parameters;
compiler_target ct;
compiler c;
double tolerance = 80;
bool per_instruction = false;
bool reduce = false;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap)
{
l.parse(ap);
parameters.parse(ap);
ct.parse(ap);
ap(offload_copy,
{"--enable-offload-copy"},
ap.help("Enable implicit offload copying"),
ap.set_value(true));
ap(fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
}
void run()
{
auto p = l.load();
l.save(p);
auto p = c.l.load();
c.l.save(p);
std::cout << p << std::endl;
compile_options options;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
auto t = ct.get_target();
auto m = parameters.generate(p, t, true);
auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true);
if(per_instruction)
{
verify_instructions(p, t, options, quantize, tolerance);
verify_instructions(p, t, c.co, c.quantize, tolerance);
}
else if(reduce)
{
verify_reduced_program(p, t, options, quantize, m, tolerance);
verify_reduced_program(p, t, c.co, c.quantize, m, tolerance);
}
else
{
verify_program(l.file, p, t, options, quantize, m, tolerance);
verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance);
}
}
};
......
......@@ -108,8 +108,6 @@ target get_target(bool gpu)
return make_target("cpu");
}
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
......@@ -37,7 +37,6 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu);
void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
......
......@@ -87,7 +87,7 @@ struct check_shapes
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* Require the number of shape objects to equal to one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
......@@ -100,6 +100,23 @@ struct check_shapes
return *this;
}
/*!
* Require the number of shape objects to equal at least a given amount. Use this
* method for ops that can take any number (variadic) of inputs.
* \param n min. number of shapes
*/
const check_shapes& has_at_least(std::size_t n) const
{
if(this->size() < n)
MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
to_string(n) + " but given " + std::to_string(size()));
return *this;
}
/*!
* Require all shapes to have the same number of elements.
* \param n number of
*/
const check_shapes& nelements(std::size_t n) const
{
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
......
......@@ -32,8 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct compile_options
{
bool offload_copy = false;
bool fast_math = true;
bool offload_copy = false;
bool fast_math = true;
bool exhaustive_tune = false;
tracer trace{};
};
......
......@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
{
return {};
}
template <class T>
void wait_for_context(T&, any_ptr)
{
......@@ -302,7 +303,7 @@ struct context
PrivateDetailTypeErasedT value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(value)
: private_detail_te_value(std::move(value))
{
}
......@@ -412,6 +413,7 @@ inline const ValueType& any_cast(const context& x)
#endif
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif
......
......@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
namespace std {
template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T>
struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
{
};
template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T>
struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
......
......@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
namespace std {
template <>
struct hash<migraphx::instruction_ref>
struct hash<migraphx::instruction_ref> // NOLINT
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
......@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
};
template <>
struct equal_to<migraphx::instruction_ref>
struct equal_to<migraphx::instruction_ref> // NOLINT
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
......
......@@ -36,22 +36,46 @@ template <class F>
struct layernorm_matcher
{
F f;
auto last_axis() const
{
return make_basic_pred_matcher([](instruction_ref ins) {
auto v = ins->get_operator().to_value();
if(not v.contains("axes"))
return false;
auto axes = v["axes"].to_vector<std::size_t>();
if(axes.size() != 1)
return false;
return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
});
}
auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
auto x_minus_mean() const
{
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
}
auto variance() const
{
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
return reduce_mean()(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
}
auto layernorm_onnx() const
auto sqrt_add_eps(const std::string& name) const
{
return f("div")(arg(0)(x_minus_mean()),
auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
}
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
auto layernorm_onnx() const
{
auto div_sqrt = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt));
}
auto matcher() const { return layernorm_onnx(); }
......
......@@ -39,7 +39,7 @@ struct memory_coloring
{
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
std::string name() const { return "memory_coloring"; }
void apply(module& m) const;
};
......
......@@ -54,6 +54,10 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/
struct module
{
// used by replace_allocate pass
// allocate memory in this module rather than using output parmaeters
bool use_local_alloc = false;
module(const std::string& name = "");
// move constructor
......
......@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
......@@ -73,49 +74,87 @@ struct concat
}
return offsets;
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
// inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
// be at least 1.
check_shapes{inputs, *this, true}.same_ndims().same_type();
if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
{
MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0");
// Static input shapes
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
{
if(ll != axis)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[ll] == first_shape_lens[ll];
}))
{
MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
std::to_string(ll));
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens = first_shape_lens;
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
else if(std::all_of(
inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
{
if(l != axis)
// Dynamic input shapes
for(std::size_t index = 0; index < inputs[0].ndim(); index++)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
}))
if(index != axis)
{
MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match");
if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
}))
MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
std::to_string(index));
}
}
std::size_t new_min = 0;
std::size_t new_max = 0;
for(const auto& input : inputs)
{
auto ddim = input.dyn_dims()[axis];
new_min += ddim.min;
new_max += ddim.max;
}
auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0};
return {inputs[0].type(), new_dims};
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
else
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
argument result{dyn_out.computed_shape};
std::vector<std::size_t> coffsets = compute_offsets(dyn_out.computed_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
auto slice_shape = shape{dyn_out.computed_shape.type(),
input.get_shape().lens(),
dyn_out.computed_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
std::copy(input.begin(), input.end(), slice.begin());
});
}
......
......@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
......@@ -61,35 +62,59 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
auto type = inputs[0].type();
lens.erase(lens.begin() + axis);
if(not inputs[1].scalar())
check_shapes{inputs, *this, true}.has(2);
shape data = inputs[0];
shape indices = inputs[1];
auto type = data.type();
// If index_dims is dynamic, convert the data to dynamic too.
if(indices.dynamic())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
data = data.to_dynamic();
}
// for scalar output
if(lens.empty())
if(data.dynamic())
{
return {type};
auto dims = data.dyn_dims();
dims.erase(dims.begin() + axis);
if(not indices.scalar())
{
auto index_dims = indices.to_dynamic().dyn_dims();
dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
}
return {type, dims};
}
else
{
// Both data and indices are static. indices may be scalar
auto lens = data.lens();
lens.erase(lens.begin() + axis);
return {type, lens};
if(not indices.scalar())
{
auto ind_lens = indices.lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
if(lens.empty())
{
return {type};
}
return {type, lens};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
// negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis];
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
if(dyn_out.computed_shape.scalar())
{
auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
......@@ -47,33 +48,103 @@ struct gathernd
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
check_shapes{inputs, *this, true}.has(2);
auto i_shape = inputs.back();
auto data_shape = inputs.front();
auto r = data_shape.ndim();
auto q = i_shape.ndim();
size_t k;
if(i_shape.dynamic())
{
// the rank of the output is a function of k, so it must be fixed.
if(not i_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = i_shape.dyn_dims().back().min;
}
else
k = i_shape.lens().back();
// Begin input validation checks.
int output_ndim = int(q) + r - k - batch_dims - 1;
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
if(batch_dims >= q or batch_dims >= r)
{
MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
std::to_string(batch_dims));
}
if(output_ndim < 0)
{
MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
std::to_string(k));
}
if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
{
auto indices_lens_iter = i_shape.lens().begin();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape{data_shape.type(), {1}};
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<std::size_t> output_lens(output_ndim);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_lens = data_shape.lens();
std::copy(data_lens.begin() + batch_dims + k,
data_lens.end(),
output_lens.begin() + q - 1);
}
shape output_shape{data_shape.type(), output_lens};
return output_shape;
}
else
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
// If one or both inputs are dynamic shapes, the output is dynamic.
// Make both inputs dynamic to simplify computations.
data_shape = data_shape.to_dynamic();
i_shape = i_shape.to_dynamic();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim);
std::copy(i_shape.dyn_dims().begin(),
i_shape.dyn_dims().begin() + q - 1,
output_dims.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_dims = data_shape.dyn_dims();
std::copy(data_dims.begin() + batch_dims + k,
data_dims.begin() + r,
output_dims.begin() + q - 1);
}
shape output_shape(data_shape.type(), output_dims);
return output_shape;
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
......
......@@ -31,18 +31,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
// 5) normalize padding
/**
* `normalize_attribute` settings:
* Note that default options are not included as enums.
* 1. `use_input` (default) vs. `use_output`:
* Affects the rank of the attribute.
* `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* Include or exclude the minimum value/index for range checking and clipping.
* 5. `clip_max` vs. `not_clip_max` (default):
* Clip values greater than the maximum or not.
* 6. `include_max` vs. `exclude_max` (default):
* Include or exclude the maximum value/index for range checking and clipping.
* 7. `normalize_padding`:
* To normalize the padding to `2*(pad ndim)` dimensions.
*/
enum class normalize_attribute
{
use_len,
use_output,
use_len,
clip_max,
clip_min,
include_max,
......
......@@ -28,6 +28,7 @@
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp>
......@@ -60,6 +61,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs[0].with_lens(inputs[0].lens());
}
......
......@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template <class Derived>
struct scatternd_op : op_name<Derived>
{
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
check_shapes{inputs, *this, true}.has(3);
auto data_shape = inputs.front();
auto index_shape = inputs.at(1);
auto upd_shape = inputs.back();
auto r = data_shape.ndim();
auto q = index_shape.ndim();
size_t k;
if(index_shape.dynamic())
{
// the rank of the output is a function of k, so k must be fixed.
if(not index_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = index_shape.dyn_dims().back().min;
}
else
k = index_shape.lens().back();
// Checks on the sizes of input tensors
if(q + r != upd_shape.ndim() + k + 1)
MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) + " + " +
std::to_string(r) + " - " + std::to_string(k) +
" - 1 != " + std::to_string(upd_shape.ndim()));
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
// Convert all static shape dimensions to dynamic so they can be compared.
// It's possible for some of the 3 inputs to be dynamic shapes and some static,
// but any dynamic dimension that's compared to a static dimension must be fixed.
auto ind_dims = index_shape.to_dynamic().dyn_dims();
auto upd_dims = upd_shape.to_dynamic().dyn_dims();
auto data_dims = data_shape.to_dynamic().dyn_dims();
// Check that corresponding portions of tensor shapes match.
if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
"indices and data.");
if(data_shape.dynamic())
return data_shape;
else if(data_shape.broadcasted())
{
return {s.type(), s.lens()};
return {data_shape.type(), data_shape.lens()};
}
else
{
return s.with_lens(s.lens());
return data_shape.with_lens(data_shape.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
......@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
......@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
});
});
});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct select_module
{
shape output_dyn_shapes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_dyn_shapes, "output_dyn_shapes"));
}
std::string name() const { return "select_module"; }
shape compute_shape(const std::vector<shape>&, const std::vector<module_ref>&) const
{
return shape{output_dyn_shapes};
}
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& submodule_list,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
// find submodule with input parameter shapes exactly the same as the input arguments
// assuming arguments are in the same order as the input parameters
auto module_iter =
std::find_if(submodule_list.cbegin(), submodule_list.cend(), [&](module_ref mr) {
auto param_names = mr->get_parameter_names();
assert(param_names.size() <= args.size());
return std::equal(param_names.cbegin(),
param_names.cend(),
args.cbegin(),
[&](auto p_name, auto a) {
return a.get_shape() == mr->get_parameter_shape(p_name);
});
});
if(module_iter == submodule_list.end())
{
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
}
auto* module_to_run = *module_iter;
std::unordered_map<std::string, argument> params;
// add input parameters
auto param_names = module_to_run->get_parameter_names();
assert(param_names.size() <= args.size());
std::transform(param_names.begin(),
param_names.end(),
args.begin(),
std::inserter(params, params.end()),
[](auto&& name, auto&& a) { return std::make_pair(name, a); });
auto results = run(module_to_run, params);
return argument{results};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
......@@ -46,6 +47,10 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
}
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value attributes() const
{
value normalize = value::object{};
......@@ -65,14 +70,6 @@ struct slice
std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const
{
const std::vector<std::size_t>& lens = s.lens();
......@@ -83,14 +80,14 @@ struct slice
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
offset += starts[i] * strides[axis];
}
}
else
{
for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
offset += starts[axis] * strides[axis];
}
}
return offset;
......@@ -98,37 +95,81 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0];
auto t = input_shape.type();
if(std::any_of(
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); }))
// TODO: When support for dynamic shapes is added to normalize_attributes,
// remove this restriction.
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
}
if(starts.size() != axes.size() or axes.size() != ends.size())
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides;
if(input_shape.dynamic())
{
old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
}
else
{
MIGRAPHX_THROW("SLICE: inconsistent sizes");
old_lens = input_shape.lens();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides = input_shape.strides();
}
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
auto axis = axes[i];
size_t sliced_length = ends[i] - starts[i];
// A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens[axis] = std::min(new_lens[axis], sliced_length);
if(input_shape.dynamic())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
}
}
if(input_shape.dynamic())
{
return shape{t, new_mins, new_lens, new_opts};
}
else
{
return shape{t, new_lens, old_strides};
}
return shape{t, new_lens, old_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment