"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "3c133f818b7dd51e2ee2ec8f68e5211de736f2f2"
Commit a0edd061 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-improve

parents 6deee23b 2781ccd8
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN PATH=/opt/cmake/bin:$PATH cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@02078ce236ad90e3aec04c0c770ef5bfc99e49c2 RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@26a4b3cfc0a1a15181490f24ae461608fef1b04e -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -88,6 +88,7 @@ add_library(migraphx ...@@ -88,6 +88,7 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -93,9 +93,11 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -93,9 +93,11 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& m) const template <class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
{ {
std::vector<instruction_ref> const_instruction; auto last = std::prev(m.end());
std::vector<instruction_ref> const_instructions;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
...@@ -103,6 +105,12 @@ void eliminate_contiguous::apply(module& m) const ...@@ -103,6 +105,12 @@ void eliminate_contiguous::apply(module& m) const
if(ins->name() == "@return") if(ins->name() == "@return")
continue; continue;
if(ins != last and ins->outputs().empty())
continue;
if(not f(ins))
continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args; auto new_args = args;
...@@ -110,36 +118,46 @@ void eliminate_contiguous::apply(module& m) const ...@@ -110,36 +118,46 @@ void eliminate_contiguous::apply(module& m) const
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() != op_name)
continue;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{ {
auto prev = arg->inputs().front(); const_instructions.push_back(arg);
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
const_instruction.push_back(arg);
}
} }
} }
} }
// Perform evaluations in parallel // Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size()); std::vector<argument> literals(const_instructions.size());
par_for(const_instruction.size(), 1, [&](const auto i) { par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{}; auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front(); auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
}); });
for(size_t i = 0; i < const_instruction.size(); i++) for(size_t i = 0; i < const_instructions.size(); i++)
{ {
auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l); m.replace_instruction(const_instructions[i], l);
} }
} }
void eliminate_contiguous::apply(module& m) const
{
// Skip contiguous from splits first
remove_contiguous(op_name, m, [](auto ins) {
if(ins->name() != "slice")
return true;
return (ins->inputs().front()->outputs().size() == 1);
});
remove_contiguous(op_name, m, [](auto) { return true; });
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#include <migraphx/support_metric.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct assignment_options
{
support_metric metric = support_metric::latency;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
...@@ -71,6 +71,11 @@ struct check_shapes ...@@ -71,6 +71,11 @@ struct check_shapes
return end - begin; return end - begin;
} }
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template <class... Ts> template <class... Ts>
const check_shapes& has(Ts... ns) const const check_shapes& has(Ts... ns) const
{ {
......
...@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module
template <class... Ms> template <class Mod, class... Ms>
void find_matches(module& mod, instruction_ref ins, Ms&&... ms) void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
#endif #endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{}); int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
bool match = false; bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
auto r = match_instruction(mod, ins, m.matcher()); if(trace > 1)
if(r.result == mod.end()) std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return; return;
if(trace) if(trace > 0)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
mod.debug_print(ins); get_module(mod).debug_print(ins);
} }
m.apply(mod, r); m.apply(mod, r);
match = true; match = true;
...@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms) ...@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
} }
/// Find matches in a module /// Find matches in a module
template <class... Ms> template <class Mod, class... Ms>
void find_matches(module& mod, Ms&&... ms) void find_matches(Mod& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(mod)) for(auto ins : iterator_for(get_module(mod)))
{ {
find_matches(mod, ins, ms...); find_matches(mod, ins, ms...);
} }
......
...@@ -124,7 +124,7 @@ struct module ...@@ -124,7 +124,7 @@ struct module
std::unordered_map<instruction_ref, instruction_ref> map_ins = {}); std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref> std::vector<instruction_ref>
add_instructions(module_ref m, add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {}); std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref> std::vector<instruction_ref>
...@@ -139,7 +139,7 @@ struct module ...@@ -139,7 +139,7 @@ struct module
std::vector<instruction_ref> std::vector<instruction_ref>
insert_instructions(instruction_ref ins, insert_instructions(instruction_ref ins,
module_ref m, const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {}); std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref> std::vector<instruction_ref>
...@@ -164,6 +164,10 @@ struct module ...@@ -164,6 +164,10 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args); instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref insert_literal(instruction_ref ins, literal l);
instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
...@@ -227,6 +231,8 @@ struct module ...@@ -227,6 +231,8 @@ struct module
std::unique_ptr<module_impl> impl; std::unique_ptr<module_impl> impl;
}; };
inline module& get_module(module& m) { return m; }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -32,7 +32,8 @@ namespace migraphx { ...@@ -32,7 +32,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
using module_ref = module*; using module_ref = module*;
using const_module_ref = const module*;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -42,11 +42,12 @@ namespace op { ...@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const value attributes() const
...@@ -73,6 +74,9 @@ struct unsqueeze ...@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
} }
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
...@@ -80,16 +84,27 @@ struct unsqueeze ...@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
new_lens[i] = 1; std::int64_t step = 1;
if(p == 0) // unsqueeze on the first axes if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
new_strides[i] = old_lens[0] * old_strides[0]; if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
} }
else // unsqueeze on middle or last axes else
{ {
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
else else
......
...@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op) ...@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result; return result;
} }
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation); std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std::vector<int64_t> find_permutation(const shape& s); std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes); std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
...@@ -84,6 +86,9 @@ struct program ...@@ -84,6 +86,9 @@ struct program
instruction_ref validate() const; instruction_ref validate() const;
target_assignments get_target_assignments(const std::vector<target>& targets,
assignment_options options = assignment_options{});
void compile(const target& t, compile_options options = compile_options{}); void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const; bool is_compiled() const;
......
...@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f) ...@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std::transform(r.begin(), r.end(), it, f); std::transform(r.begin(), r.end(), it, f);
} }
template <class Range1, class Range2, class Iterator, class F>
void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
{
std::transform(r1.begin(), r1.end(), r2.begin(), it, f);
}
template <class Range> template <class Range>
auto reverse(Range& r) auto reverse(Range& r)
{ {
......
...@@ -82,6 +82,23 @@ struct shape ...@@ -82,6 +82,23 @@ struct shape
{ {
}; };
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
template <class Self, class F>
static auto reflect(Self& self, F f);
bool is_fixed() const;
bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
};
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t); static std::string name(type_t t);
...@@ -92,6 +109,12 @@ struct shape ...@@ -92,6 +109,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::vector<dynamic_dimension> dims);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{ {
...@@ -112,10 +135,44 @@ struct shape ...@@ -112,10 +135,44 @@ struct shape
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor.
*/
std::size_t elements() const; std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std::size_t bytes() const; std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const; std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> min_lens() const;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> max_lens() const;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index /// Map multiple indices to space index
...@@ -136,19 +193,27 @@ struct shape ...@@ -136,19 +193,27 @@ struct shape
std::vector<std::size_t> multi(std::size_t i) const; std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
/// order /// order
bool transposed() const; bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const; bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and /// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed. /// not transposed.
bool standard() const; bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor) /// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const; bool scalar() const;
/// Return true if the shape is dynamic
bool dynamic() const;
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
...@@ -191,6 +256,10 @@ struct shape ...@@ -191,6 +256,10 @@ struct shape
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; } std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
auto is_integral() const { return std::is_integral<type>{}; }
auto is_signed() const { return std::is_signed<type>{}; }
auto is_unsigned() const { return std::is_unsigned<type>{}; }
template <class U> template <class U>
type* from(U* buffer, std::size_t n = 0) const type* from(U* buffer, std::size_t n = 0) const
{ {
...@@ -248,6 +317,11 @@ struct shape ...@@ -248,6 +317,11 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std::size_t element_space() const; std::size_t element_space() const;
private: private:
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class support_metric
{
latency,
throughput
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
...@@ -117,6 +132,8 @@ struct target ...@@ -117,6 +132,8 @@ struct target
// //
context get_context() const; context get_context() const;
// (optional) // (optional)
float is_supported(instruction_ref ins, support_metric m) const;
// (optional)
argument copy_to(const argument& input) const; argument copy_to(const argument& input) const;
// (optional) // (optional)
argument copy_from(const argument& input) const; argument copy_from(const argument& input) const;
...@@ -207,6 +224,12 @@ struct target ...@@ -207,6 +224,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
float is_supported(instruction_ref ins, support_metric m) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m);
}
argument copy_to(const argument& input) const argument copy_to(const argument& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -242,11 +265,31 @@ struct target ...@@ -242,11 +265,31 @@ struct target
virtual std::vector<pass> get_passes(context& ctx, virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0; const compile_options& options) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T>
static auto private_detail_te_default_is_supported(char,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m))
{
return private_detail_te_self.is_supported(ins, m);
}
template <class T>
static float private_detail_te_default_is_supported(float,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
{
return target_is_supported(private_detail_te_self, ins, m);
}
template <class T> template <class T>
static auto static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input) private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
...@@ -329,6 +372,12 @@ struct target ...@@ -329,6 +372,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override
{
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
}
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments
{
void add_assignment(instruction_ref ins, const std::string& target);
auto begin() const { return assignments.cbegin(); }
auto end() const { return assignments.cend(); }
private:
std::unordered_map<instruction_ref, std::string> assignments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
...@@ -399,7 +399,8 @@ module::add_instructions(const std::vector<instruction_ref>& instructions, ...@@ -399,7 +399,8 @@ module::add_instructions(const std::vector<instruction_ref>& instructions,
} }
std::vector<instruction_ref> std::vector<instruction_ref>
module::add_instructions(module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins) module::add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return this->insert_instructions(this->end(), m, std::move(map_ins)); return this->insert_instructions(this->end(), m, std::move(map_ins));
} }
...@@ -420,8 +421,10 @@ module::insert_instructions(instruction_ref ins, ...@@ -420,8 +421,10 @@ module::insert_instructions(instruction_ref ins,
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins)); return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
} }
std::vector<instruction_ref> module::insert_instructions( std::vector<instruction_ref>
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins) module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins)); return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
} }
...@@ -436,11 +439,7 @@ module::insert_instructions(instruction_ref ins, ...@@ -436,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins)); return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
} }
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
{
impl->emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref module::add_outline(const shape& s) instruction_ref module::add_outline(const shape& s)
{ {
...@@ -450,10 +449,7 @@ instruction_ref module::add_outline(const shape& s) ...@@ -450,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s) instruction_ref module::add_parameter(std::string name, shape s)
{ {
assert(get_parameter_shape(name) == shape{}); return insert_parameter(begin(), std::move(name), std::move(s));
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
} }
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
...@@ -466,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args) ...@@ -466,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result; return result;
} }
instruction_ref module::insert_literal(instruction_ref ins, literal l)
{
impl->emplace(ins, std::move(l));
return std::prev(ins);
}
instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->insert(ins, {builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return std::prev(ins);
}
instruction_ref module::replace_return(std::vector<instruction_ref> args) instruction_ref module::replace_return(std::vector<instruction_ref> args)
{ {
auto last = std::prev(this->end()); auto last = std::prev(this->end());
......
...@@ -159,6 +159,25 @@ instruction_ref program::validate() const ...@@ -159,6 +159,25 @@ instruction_ref program::validate() const
return mm->validate(); return mm->validate();
} }
target_assignments program::get_target_assignments(const std::vector<target>& targets,
assignment_options options)
{
const auto m = options.metric;
target_assignments p;
const auto* mod = get_main_module();
for(auto it : iterator_for(*mod))
{
auto t = std::max_element(
targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
return lhs.is_supported(it, m) < rhs.is_supported(it, m);
});
p.add_assignment(it, t->name());
}
return p;
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
...@@ -504,12 +523,14 @@ static void mod_from_val(module_ref mod, ...@@ -504,12 +523,14 @@ static void mod_from_val(module_ref mod,
if(name == "@param") if(name == "@param")
{ {
output = mod->add_parameter(fields["parameter"].to<std::string>(), output = mod->insert_parameter(mod->end(),
migraphx::from_value<shape>(node.at("shape"))); fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
} }
else if(name == "@literal") else if(name == "@literal")
{ {
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal"))); output =
mod->insert_literal(mod->end(), migraphx::from_value<literal>(node.at("literal")));
} }
else else
{ {
...@@ -544,11 +565,11 @@ static void mod_from_val(module_ref mod, ...@@ -544,11 +565,11 @@ static void mod_from_val(module_ref mod,
} }
else if(module_inputs.empty()) else if(module_inputs.empty())
{ {
output = mod->add_instruction(op, inputs); output = mod->insert_instruction(mod->end(), op, inputs);
} }
else else
{ {
output = mod->add_instruction(op, inputs, module_inputs); output = mod->insert_instruction(mod->end(), op, inputs, module_inputs);
} }
} }
output->set_normalized(normalized); output->set_normalized(normalized);
...@@ -681,11 +702,13 @@ void program::perf_report(std::ostream& os, ...@@ -681,11 +702,13 @@ void program::perf_report(std::ostream& os,
double overhead_percent = overhead_time * 100.0 / total_time; double overhead_percent = overhead_time * 100.0 / total_time;
double total_instruction_time = 0.0; double total_instruction_time = 0.0;
std::unordered_map<std::string, double> op_times; std::unordered_map<std::string, double> op_times;
std::unordered_map<std::string, std::size_t> op_n;
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
{ {
double avg = common_average(p.second); double avg = common_average(p.second);
op_times[perf_group(p.first->get_operator())] += avg; op_times[perf_group(p.first->get_operator())] += avg;
total_instruction_time += avg; total_instruction_time += avg;
op_n[perf_group(p.first->get_operator())]++;
} }
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
...@@ -706,18 +729,19 @@ void program::perf_report(std::ostream& os, ...@@ -706,18 +729,19 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Summary:" << std::endl; os << "Summary:" << std::endl;
std::vector<std::pair<double, std::string>> op_times_sorted; std::vector<std::tuple<double, std::size_t, std::string>> op_times_sorted;
std::transform(op_times.begin(), std::transform(
op_times.end(), op_times.begin(), op_times.end(), std::back_inserter(op_times_sorted), [&](auto p) {
std::back_inserter(op_times_sorted), auto&& name = p.first;
[](auto p) { return std::make_pair(p.second, p.first); }); return std::make_tuple(p.second, op_n.at(name), name);
});
std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{}); std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{});
for(auto&& p : op_times_sorted) for(auto&& [avg, nn, name] : op_times_sorted)
{ {
auto&& name = p.second;
double avg = p.first;
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << name << ": " << avg << "ms, " << percent << "%" << std::endl; double per_ins = avg / nn;
os << name << ": " << avg << "ms / " << nn << " = " << per_ins << "ms, " << percent << "%"
<< std::endl;
} }
os << std::endl; os << std::endl;
......
...@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd) ...@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
result["shape"] = migraphx::to_value(rd.get_shape()); result["shape"] = migraphx::to_value(rd.get_shape());
if(rd.get_shape().type() == shape::tuple_type) if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects()); result["sub"] = migraphx::to_value(rd.get_sub_objects());
else else if(not rd.empty())
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result; v = result;
} }
...@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a) ...@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
literal l = migraphx::from_value<literal>(v); literal l = migraphx::from_value<literal>(v);
a = l.get_argument(); a = l.get_argument();
} }
else else if(v.contains("sub"))
{ {
a = migraphx::from_value<std::vector<argument>>(v.at("sub")); a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
...@@ -65,13 +66,21 @@ struct shape_impl ...@@ -65,13 +66,21 @@ struct shape_impl
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dyn_dims(std::move(dims))
{
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens = {}; std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {}; std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides() void calculate_strides()
{ {
m_strides.clear(); m_strides.clear();
...@@ -87,6 +96,12 @@ struct shape_impl ...@@ -87,6 +96,12 @@ struct shape_impl
std::size_t element_space() const std::size_t element_space() const
{ {
if(not m_dyn_dims.empty())
{
auto maxes = max_lens();
return std::accumulate(maxes.begin(), maxes.end(), std::size_t{1}, std::multiplies<>());
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -101,6 +116,11 @@ struct shape_impl ...@@ -101,6 +116,11 @@ struct shape_impl
std::size_t elements() const std::size_t elements() const
{ {
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -108,6 +128,35 @@ struct shape_impl ...@@ -108,6 +128,35 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.min; });
return ret;
}
std::vector<std::size_t> max_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.max; });
return ret;
}
std::vector<std::size_t> opt_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.opt; });
return ret;
}
// Does the shape skip over elements? // Does the shape skip over elements?
bool skips() const bool skips() const
{ {
...@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{ {
} }
shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
...@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t, ...@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
} }
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
...@@ -199,6 +262,7 @@ std::size_t shape::bytes() const ...@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
[&](auto x, auto y) { return x + y.bytes(); }); [&](auto x, auto y) { return x + y.bytes(); });
} }
} }
std::size_t shape::type_size() const std::size_t shape::type_size() const
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -206,20 +270,35 @@ std::size_t shape::type_size() const ...@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n; return n;
} }
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(const std::vector<std::size_t>& l) const std::size_t shape::index(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
...@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const bool shape::packed() const
{ {
if(this->dynamic())
{
return false;
}
return this->sub_shapes().empty() and not impl->skips() and return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space(); this->elements() == this->element_space();
} }
bool shape::transposed() const bool shape::transposed() const
{ {
if(this->dynamic())
{
return false;
}
if(this->broadcasted()) if(this->broadcasted())
{ {
// TODO: Use a filter_iterator instead // TODO: Use a filter_iterator instead
...@@ -292,6 +379,10 @@ bool shape::transposed() const ...@@ -292,6 +379,10 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::any_of( return std::any_of(
this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; }); this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
...@@ -299,6 +390,10 @@ bool shape::broadcasted() const ...@@ -299,6 +390,10 @@ bool shape::broadcasted() const
bool shape::scalar() const bool shape::scalar() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return this->sub_shapes().empty() and return this->sub_shapes().empty() and
...@@ -317,6 +412,10 @@ shape shape::normalize_standard() const ...@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
assert(l.size() == this->lens().size()); assert(l.size() == this->lens().size());
auto perm = find_permutation(*this); auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm); return shape::from_permutation(t, l, perm);
...@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const ...@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape shape::with_lens(const std::vector<std::size_t>& l) const shape shape::with_lens(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
...@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); } ...@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
{
return this->dynamic() ? impl->min_lens() : this->lens();
}
std::vector<std::size_t> shape::max_lens() const
{
return this->dynamic() ? impl->max_lens() : this->lens();
}
std::vector<std::size_t> shape::opt_lens() const
{
return this->dynamic() ? impl->opt_lens() : this->lens();
}
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and if(x.dynamic() and y.dynamic())
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes()); {
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
x.sub_shapes() == y.sub_shapes());
}
return x.impl == y.impl or
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
} }
bool operator!=(const shape& x, const shape& y) { return !(x == y); } bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
if(x.sub_shapes().empty()) if(x.sub_shapes().empty())
{ {
os << x.type_string() << ", "; if(x.dynamic())
os << "{" << to_string_range(x.lens()) << "}, "; {
os << "{" << to_string_range(x.strides()) << "}"; os << "dynamic, ";
os << x.type_string() << ", ";
os << "{" << to_string_range(x.dyn_dims()) << "}";
}
else
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
}
} }
else else
{ {
...@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; } ...@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s) void migraphx_to_value(value& v, const shape& s)
{ {
value result; value result;
result["type"] = migraphx::to_value(s.type_string()); result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens()); result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides()); result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
v = result; result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
v = result;
} }
void migraphx_from_value(const value& v, shape& s) void migraphx_from_value(const value& v, shape& s)
{ {
auto t = v.at("type").get_string(); auto t = v.at("type").get_string();
...@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s)
} }
else else
{ {
s = shape{shape::parse_type(t), if(v.at("dynamic_dimensions").empty())
v.at("lens").to_vector<std::size_t>(), {
v.at("strides").to_vector<std::size_t>()}; s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
});
s = shape{shape::parse_type(t), dyn_dims};
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment