Commit 6d582c24 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu

parents 6f692ebd 05b13c9f
...@@ -89,6 +89,7 @@ add_library(migraphx ...@@ -89,6 +89,7 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#include <migraphx/support_metric.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct assignment_options
{
support_metric metric = support_metric::latency;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
...@@ -71,6 +71,11 @@ struct check_shapes ...@@ -71,6 +71,11 @@ struct check_shapes
return end - begin; return end - begin;
} }
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template <class... Ts> template <class... Ts>
const check_shapes& has(Ts... ns) const const check_shapes& has(Ts... ns) const
{ {
......
...@@ -42,11 +42,12 @@ namespace op { ...@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const value attributes() const
...@@ -73,6 +74,9 @@ struct unsqueeze ...@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
} }
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
...@@ -80,16 +84,27 @@ struct unsqueeze ...@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
new_lens[i] = 1; std::int64_t step = 1;
if(p == 0) // unsqueeze on the first axes if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
new_strides[i] = old_lens[0] * old_strides[0]; if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
} }
else // unsqueeze on middle or last axes else
{ {
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
else else
......
...@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op) ...@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result; return result;
} }
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation); std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std::vector<int64_t> find_permutation(const shape& s); std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes); std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
...@@ -84,6 +86,9 @@ struct program ...@@ -84,6 +86,9 @@ struct program
instruction_ref validate() const; instruction_ref validate() const;
target_assignments get_target_assignments(const std::vector<target>& targets,
assignment_options options = assignment_options{});
void compile(const target& t, compile_options options = compile_options{}); void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const; bool is_compiled() const;
......
...@@ -82,6 +82,23 @@ struct shape ...@@ -82,6 +82,23 @@ struct shape
{ {
}; };
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
template <class Self, class F>
static auto reflect(Self& self, F f);
bool is_fixed() const;
bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
};
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t); static std::string name(type_t t);
...@@ -92,6 +109,12 @@ struct shape ...@@ -92,6 +109,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::vector<dynamic_dimension> dims);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{ {
...@@ -112,10 +135,44 @@ struct shape ...@@ -112,10 +135,44 @@ struct shape
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor.
*/
std::size_t elements() const; std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std::size_t bytes() const; std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const; std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> min_lens() const;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> max_lens() const;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index /// Map multiple indices to space index
...@@ -136,19 +193,27 @@ struct shape ...@@ -136,19 +193,27 @@ struct shape
std::vector<std::size_t> multi(std::size_t i) const; std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
/// order /// order
bool transposed() const; bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const; bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and /// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed. /// not transposed.
bool standard() const; bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor) /// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const; bool scalar() const;
/// Return true if the shape is dynamic
bool dynamic() const;
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
...@@ -252,6 +317,11 @@ struct shape ...@@ -252,6 +317,11 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std::size_t element_space() const; std::size_t element_space() const;
private: private:
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class support_metric
{
latency,
throughput
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
...@@ -117,6 +132,8 @@ struct target ...@@ -117,6 +132,8 @@ struct target
// //
context get_context() const; context get_context() const;
// (optional) // (optional)
float is_supported(instruction_ref ins, support_metric m) const;
// (optional)
argument copy_to(const argument& input) const; argument copy_to(const argument& input) const;
// (optional) // (optional)
argument copy_from(const argument& input) const; argument copy_from(const argument& input) const;
...@@ -207,6 +224,12 @@ struct target ...@@ -207,6 +224,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
float is_supported(instruction_ref ins, support_metric m) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m);
}
argument copy_to(const argument& input) const argument copy_to(const argument& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -242,11 +265,31 @@ struct target ...@@ -242,11 +265,31 @@ struct target
virtual std::vector<pass> get_passes(context& ctx, virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0; const compile_options& options) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T>
static auto private_detail_te_default_is_supported(char,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m))
{
return private_detail_te_self.is_supported(ins, m);
}
template <class T>
static float private_detail_te_default_is_supported(float,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
{
return target_is_supported(private_detail_te_self, ins, m);
}
template <class T> template <class T>
static auto static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input) private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
...@@ -329,6 +372,12 @@ struct target ...@@ -329,6 +372,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override
{
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
}
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments
{
void add_assignment(instruction_ref ins, const std::string& target);
auto begin() const { return assignments.cbegin(); }
auto end() const { return assignments.cend(); }
private:
std::unordered_map<instruction_ref, std::string> assignments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
...@@ -159,6 +159,25 @@ instruction_ref program::validate() const ...@@ -159,6 +159,25 @@ instruction_ref program::validate() const
return mm->validate(); return mm->validate();
} }
target_assignments program::get_target_assignments(const std::vector<target>& targets,
assignment_options options)
{
const auto m = options.metric;
target_assignments p;
const auto* mod = get_main_module();
for(auto it : iterator_for(*mod))
{
auto t = std::max_element(
targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
return lhs.is_supported(it, m) < rhs.is_supported(it, m);
});
p.add_assignment(it, t->name());
}
return p;
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
...@@ -683,11 +702,13 @@ void program::perf_report(std::ostream& os, ...@@ -683,11 +702,13 @@ void program::perf_report(std::ostream& os,
double overhead_percent = overhead_time * 100.0 / total_time; double overhead_percent = overhead_time * 100.0 / total_time;
double total_instruction_time = 0.0; double total_instruction_time = 0.0;
std::unordered_map<std::string, double> op_times; std::unordered_map<std::string, double> op_times;
std::unordered_map<std::string, std::size_t> op_n;
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
{ {
double avg = common_average(p.second); double avg = common_average(p.second);
op_times[perf_group(p.first->get_operator())] += avg; op_times[perf_group(p.first->get_operator())] += avg;
total_instruction_time += avg; total_instruction_time += avg;
op_n[perf_group(p.first->get_operator())]++;
} }
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
...@@ -708,18 +729,19 @@ void program::perf_report(std::ostream& os, ...@@ -708,18 +729,19 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Summary:" << std::endl; os << "Summary:" << std::endl;
std::vector<std::pair<double, std::string>> op_times_sorted; std::vector<std::tuple<double, std::size_t, std::string>> op_times_sorted;
std::transform(op_times.begin(), std::transform(
op_times.end(), op_times.begin(), op_times.end(), std::back_inserter(op_times_sorted), [&](auto p) {
std::back_inserter(op_times_sorted), auto&& name = p.first;
[](auto p) { return std::make_pair(p.second, p.first); }); return std::make_tuple(p.second, op_n.at(name), name);
});
std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{}); std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{});
for(auto&& p : op_times_sorted) for(auto&& [avg, nn, name] : op_times_sorted)
{ {
auto&& name = p.second;
double avg = p.first;
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << name << ": " << avg << "ms, " << percent << "%" << std::endl; double per_ins = avg / nn;
os << name << ": " << avg << "ms / " << nn << " = " << per_ins << "ms, " << percent << "%"
<< std::endl;
} }
os << std::endl; os << std::endl;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
...@@ -65,13 +66,21 @@ struct shape_impl ...@@ -65,13 +66,21 @@ struct shape_impl
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dyn_dims(std::move(dims))
{
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens = {}; std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {}; std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides() void calculate_strides()
{ {
m_strides.clear(); m_strides.clear();
...@@ -87,6 +96,12 @@ struct shape_impl ...@@ -87,6 +96,12 @@ struct shape_impl
std::size_t element_space() const std::size_t element_space() const
{ {
if(not m_dyn_dims.empty())
{
auto maxes = max_lens();
return std::accumulate(maxes.begin(), maxes.end(), std::size_t{1}, std::multiplies<>());
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -101,6 +116,11 @@ struct shape_impl ...@@ -101,6 +116,11 @@ struct shape_impl
std::size_t elements() const std::size_t elements() const
{ {
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
...@@ -108,6 +128,35 @@ struct shape_impl ...@@ -108,6 +128,35 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.min; });
return ret;
}
std::vector<std::size_t> max_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.max; });
return ret;
}
std::vector<std::size_t> opt_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.opt; });
return ret;
}
// Does the shape skip over elements? // Does the shape skip over elements?
bool skips() const bool skips() const
{ {
...@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{ {
} }
shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
...@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t, ...@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
} }
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
...@@ -199,6 +262,7 @@ std::size_t shape::bytes() const ...@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
[&](auto x, auto y) { return x + y.bytes(); }); [&](auto x, auto y) { return x + y.bytes(); });
} }
} }
std::size_t shape::type_size() const std::size_t shape::type_size() const
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -206,20 +270,35 @@ std::size_t shape::type_size() const ...@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n; return n;
} }
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(const std::vector<std::size_t>& l) const std::size_t shape::index(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
} }
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
...@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const bool shape::packed() const
{ {
if(this->dynamic())
{
return false;
}
return this->sub_shapes().empty() and not impl->skips() and return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space(); this->elements() == this->element_space();
} }
bool shape::transposed() const bool shape::transposed() const
{ {
if(this->dynamic())
{
return false;
}
if(this->broadcasted()) if(this->broadcasted())
{ {
// TODO: Use a filter_iterator instead // TODO: Use a filter_iterator instead
...@@ -292,6 +379,10 @@ bool shape::transposed() const ...@@ -292,6 +379,10 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::any_of( return std::any_of(
this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; }); this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
...@@ -299,6 +390,10 @@ bool shape::broadcasted() const ...@@ -299,6 +390,10 @@ bool shape::broadcasted() const
bool shape::scalar() const bool shape::scalar() const
{ {
if(this->dynamic())
{
return false;
}
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return this->sub_shapes().empty() and return this->sub_shapes().empty() and
...@@ -317,6 +412,10 @@ shape shape::normalize_standard() const ...@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
assert(l.size() == this->lens().size()); assert(l.size() == this->lens().size());
auto perm = find_permutation(*this); auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm); return shape::from_permutation(t, l, perm);
...@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const ...@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape shape::with_lens(const std::vector<std::size_t>& l) const shape shape::with_lens(const std::vector<std::size_t>& l) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
}
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
...@@ -338,21 +441,81 @@ std::size_t shape::element_space() const { return impl->element_space(); } ...@@ -338,21 +441,81 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
{
return this->dynamic() ? impl->min_lens() : this->lens();
}
std::vector<std::size_t> shape::max_lens() const
{
return this->dynamic() ? impl->max_lens() : this->lens();
}
std::vector<std::size_t> shape::opt_lens() const
{
return this->dynamic() ? impl->opt_lens() : this->lens();
}
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
return os;
}
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and if(x.dynamic() and y.dynamic())
{
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
x.sub_shapes() == y.sub_shapes());
}
return x.impl == y.impl or
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes()); x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
} }
bool operator!=(const shape& x, const shape& y) { return !(x == y); } bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
if(x.sub_shapes().empty()) if(x.sub_shapes().empty())
{
if(x.dynamic())
{
os << "dynamic, ";
os << x.type_string() << ", ";
os << "{" << to_string_range(x.dyn_dims()) << "}";
}
else
{ {
os << x.type_string() << ", "; os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, "; os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}"; os << "{" << to_string_range(x.strides()) << "}";
} }
}
else else
{ {
os << "[" << to_string_range(x.sub_shapes()) << "]"; os << "[" << to_string_range(x.sub_shapes()) << "]";
...@@ -379,8 +542,10 @@ void migraphx_to_value(value& v, const shape& s) ...@@ -379,8 +542,10 @@ void migraphx_to_value(value& v, const shape& s)
result["lens"] = migraphx::to_value(s.lens()); result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides()); result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
v = result; v = result;
} }
void migraphx_from_value(const value& v, shape& s) void migraphx_from_value(const value& v, shape& s)
{ {
auto t = v.at("type").get_string(); auto t = v.at("type").get_string();
...@@ -389,11 +554,27 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -389,11 +554,27 @@ void migraphx_from_value(const value& v, shape& s)
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))}; s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
} }
else else
{
if(v.at("dynamic_dimensions").empty())
{ {
s = shape{shape::parse_type(t), s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(), v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()}; v.at("strides").to_vector<std::size_t>()};
} }
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
});
s = shape{shape::parse_type(t), dyn_dims};
}
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -272,7 +272,7 @@ struct find_concat_transpose ...@@ -272,7 +272,7 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target_assignments.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void target_assignments::add_assignment(instruction_ref ins, const std::string& target)
{
assignments.emplace(ins, target);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG_SYM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; params += " --std=c++17";
params += " -fno-gpu-rdc"; params += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) if(is_hcc_compiler())
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "test.hpp"
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/target_assignments.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto diff = mm->add_instruction(migraphx::make_op("div"), x, y);
mm->add_instruction(migraphx::make_op("div"), diff, z);
return p;
}
TEST_CASE(is_supported)
{
auto p = create_program();
auto targets = migraphx::get_targets();
EXPECT(!targets.empty());
auto first_target = targets[0];
auto t = migraphx::make_target(first_target);
const auto assignments = p.get_target_assignments({t});
for(const auto& [ins, target] : assignments)
{
(void)ins;
EXPECT(target == first_target);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -981,7 +981,8 @@ TEST_CASE(multibroadcast) ...@@ -981,7 +981,8 @@ TEST_CASE(multibroadcast)
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}}; std::vector<std::size_t> empt = {};
migraphx::shape input{migraphx::shape::float_type, empt};
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
...@@ -1533,15 +1534,46 @@ TEST_CASE(test_squeeze_wrong_axis) ...@@ -1533,15 +1534,46 @@ TEST_CASE(test_squeeze_wrong_axis)
TEST_CASE(test_unsqueeze) TEST_CASE(test_unsqueeze)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 2, 6}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_non_divisable)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_step_at_end)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {3}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_mismatch_step_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(test_unsqueeze_negative_axis)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
...@@ -1567,21 +1599,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1567,21 +1599,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
TEST_CASE(test_unsqueeze_transpose) TEST_CASE(test_unsqueeze_transpose)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}}; migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}}; migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_transpose_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 6}, {24, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 2, 3}, {24, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast) TEST_CASE(test_unsqueeze_multibroadcast)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 0, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_slice) TEST_CASE(test_unsqueeze_slice)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
...@@ -1613,6 +1652,27 @@ TEST_CASE(test_unsqueeze_multiple_axes_2) ...@@ -1613,6 +1652,27 @@ TEST_CASE(test_unsqueeze_multiple_axes_2)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
} }
TEST_CASE(test_unsqueeze_multiple_axes_3)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_4)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {5, 4, 2}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_step)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 10}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 2, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}, {"steps", {2}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
...@@ -38,7 +38,6 @@ TEST_CASE(test_shape_default) ...@@ -38,7 +38,6 @@ TEST_CASE(test_shape_default)
EXPECT(s.elements() == 0); EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0); EXPECT(s.bytes() == 0);
} }
TEST_CASE(test_shape_assign) TEST_CASE(test_shape_assign)
{ {
migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}}; migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
...@@ -65,6 +64,118 @@ TEST_CASE(test_shape_standard) ...@@ -65,6 +64,118 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_min_max_opt)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens());
}
TEST_CASE(test_shape_dynamic_fixed)
{
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 3);
EXPECT(s.dyn_dims().at(0).is_fixed());
EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0});
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_not_fixed)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.dynamic());
EXPECT(s.dyn_dims().size() == 2);
EXPECT(not s.dyn_dims().at(0).is_fixed());
EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0});
EXPECT(s.bytes() == 5 * 8 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto b = a;
auto c = shape::dynamic_dimension{2, 5, 2};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
EXPECT(a == c);
EXPECT(a != d);
migraphx::shape s0{shape::float_type, {a, d}};
migraphx::shape s1 = s0;
migraphx::shape s2{shape::float_type, {a, d}};
migraphx::shape s3{shape::int32_type, {a}};
EXPECT(s0 == s1);
EXPECT(s0 == s2);
EXPECT(s0 != s3);
std::stringstream ss0;
std::stringstream ss1;
std::stringstream ss3;
ss0 << s0;
ss1 << s1;
ss3 << s3;
EXPECT(ss0.str() == ss1.str());
EXPECT(ss0.str() != ss3.str());
}
TEST_CASE(test_shape_dynamic_errors)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); }));
EXPECT(test::throws([&] { s.index(1); }));
EXPECT(test::throws([&] { s.index(std::vector<std::size_t>{0, 1}); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); }));
}
TEST_CASE(test_shape_dynamic_serialize)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2});
dims1.push_back(shape::dynamic_dimension{2, 8, 0});
migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2});
migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
auto s3 = migraphx::from_value<shape>(v1);
EXPECT(s3 == s1);
auto s4 = migraphx::from_value<shape>(v2);
EXPECT(s4 == s2);
EXPECT(s3 != s4);
}
TEST_CASE(test_shape_packed) TEST_CASE(test_shape_packed)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......
...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast) ...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_unsqueeze_concat)
{
migraphx::module m1;
{
auto l0 = m1.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l1 = m1.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = m1.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 3;
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
m1.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args);
}
// TODO: This could be simplified to a single transpose after concat
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice) TEST_CASE(transpose_slice)
{ {
migraphx::module m1; migraphx::module m1;
......
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment