Commit 22fee23b authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_check_shapes' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_nms

parents 0746d6a7 d6afa9e9
...@@ -88,6 +88,7 @@ add_library(migraphx ...@@ -88,6 +88,7 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#include <migraphx/support_metric.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct assignment_options
{
support_metric metric = support_metric::latency;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
...@@ -38,22 +38,34 @@ struct check_shapes ...@@ -38,22 +38,34 @@ struct check_shapes
const shape* begin; const shape* begin;
const shape* end; const shape* end;
const std::string name; const std::string name;
const bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n) check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d)
{ {
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name()) check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
} }
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()) : begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
{ {
} }
~check_shapes()
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
{
std::cerr << prefix() << "Dynamic shapes not supported" << std::endl;
std::abort();
}
}
std::string prefix() const std::string prefix() const
{ {
if(name.empty()) if(name.empty())
...@@ -92,6 +104,11 @@ struct check_shapes ...@@ -92,6 +104,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -104,6 +121,11 @@ struct check_shapes ...@@ -104,6 +121,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -117,6 +139,11 @@ struct check_shapes ...@@ -117,6 +139,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -130,6 +157,9 @@ struct check_shapes ...@@ -130,6 +157,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same shape.
*/
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
...@@ -137,6 +167,9 @@ struct check_shapes ...@@ -137,6 +167,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same type.
*/
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
...@@ -144,6 +177,9 @@ struct check_shapes ...@@ -144,6 +177,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same lens.
*/
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens(); })) if(!this->same([](const shape& s) { return s.max_lens(); }))
...@@ -151,6 +187,9 @@ struct check_shapes ...@@ -151,6 +187,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same number of dimensions.
*/
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens().size(); })) if(!this->same([](const shape& s) { return s.max_lens().size(); }))
...@@ -158,6 +197,9 @@ struct check_shapes ...@@ -158,6 +197,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are standard.
*/
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(!this->all_of([](const shape& s) { return s.standard(); }))
...@@ -165,6 +207,9 @@ struct check_shapes ...@@ -165,6 +207,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are standard or scalar.
*/
const check_shapes& standard_or_scalar() const const check_shapes& standard_or_scalar() const
{ {
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); })) if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
...@@ -172,6 +217,9 @@ struct check_shapes ...@@ -172,6 +217,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed.
*/
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(!this->all_of([](const shape& s) { return s.packed(); }))
...@@ -179,6 +227,9 @@ struct check_shapes ...@@ -179,6 +227,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed or broadcasted.
*/
const check_shapes& packed_or_broadcasted() const const check_shapes& packed_or_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
...@@ -186,6 +237,9 @@ struct check_shapes ...@@ -186,6 +237,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are tuples.
*/
const check_shapes& tuple_type() const const check_shapes& tuple_type() const
{ {
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
...@@ -193,6 +247,9 @@ struct check_shapes ...@@ -193,6 +247,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not transposed.
*/
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(!this->all_of([](const shape& s) { return not s.transposed(); }))
...@@ -200,6 +257,9 @@ struct check_shapes ...@@ -200,6 +257,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not broadcasted.
*/
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
...@@ -207,6 +267,10 @@ struct check_shapes ...@@ -207,6 +267,10 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const check_shapes& elements(std::size_t n) const const check_shapes& elements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
...@@ -214,6 +278,9 @@ struct check_shapes ...@@ -214,6 +278,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const check_shapes& batch_not_transposed() const const check_shapes& batch_not_transposed() const
{ {
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
...@@ -242,6 +309,16 @@ struct check_shapes ...@@ -242,6 +309,16 @@ struct check_shapes
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
template <class Predicate>
bool any_of(Predicate p) const
{
if(begin == end)
return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const const shape* get(long i) const
{ {
if(i >= size()) if(i >= size())
......
...@@ -42,11 +42,12 @@ namespace op { ...@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const value attributes() const
...@@ -73,6 +74,9 @@ struct unsqueeze ...@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
} }
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
...@@ -80,16 +84,27 @@ struct unsqueeze ...@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
new_lens[i] = 1; std::int64_t step = 1;
if(p == 0) // unsqueeze on the first axes if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
new_strides[i] = old_lens[0] * old_strides[0]; if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
} }
else // unsqueeze on middle or last axes else
{ {
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
else else
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
...@@ -84,6 +86,9 @@ struct program ...@@ -84,6 +86,9 @@ struct program
instruction_ref validate() const; instruction_ref validate() const;
target_assignments get_target_assignments(const std::vector<target>& targets,
assignment_options options = assignment_options{});
void compile(const target& t, compile_options options = compile_options{}); void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const; bool is_compiled() const;
......
...@@ -318,8 +318,9 @@ struct shape ...@@ -318,8 +318,9 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
/*! /*!
* Returns size of the data buffer. * Returns the number of elements in the data buffer.
* Assuming a packed shape, returns maximum size of the data buffer for dynamic shape. * For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/ */
std::size_t element_space() const; std::size_t element_space() const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class support_metric
{
latency,
throughput
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
...@@ -117,6 +132,8 @@ struct target ...@@ -117,6 +132,8 @@ struct target
// //
context get_context() const; context get_context() const;
// (optional) // (optional)
float is_supported(instruction_ref ins, support_metric m) const;
// (optional)
argument copy_to(const argument& input) const; argument copy_to(const argument& input) const;
// (optional) // (optional)
argument copy_from(const argument& input) const; argument copy_from(const argument& input) const;
...@@ -207,6 +224,12 @@ struct target ...@@ -207,6 +224,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
float is_supported(instruction_ref ins, support_metric m) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m);
}
argument copy_to(const argument& input) const argument copy_to(const argument& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -242,11 +265,31 @@ struct target ...@@ -242,11 +265,31 @@ struct target
virtual std::vector<pass> get_passes(context& ctx, virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0; const compile_options& options) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T>
static auto private_detail_te_default_is_supported(char,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m))
{
return private_detail_te_self.is_supported(ins, m);
}
template <class T>
static float private_detail_te_default_is_supported(float,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
{
return target_is_supported(private_detail_te_self, ins, m);
}
template <class T> template <class T>
static auto static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input) private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
...@@ -329,6 +372,12 @@ struct target ...@@ -329,6 +372,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override
{
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
}
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments
{
void add_assignment(instruction_ref ins, const std::string& target);
auto begin() const { return assignments.cbegin(); }
auto end() const { return assignments.cend(); }
private:
std::unordered_map<instruction_ref, std::string> assignments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
...@@ -159,6 +159,25 @@ instruction_ref program::validate() const ...@@ -159,6 +159,25 @@ instruction_ref program::validate() const
return mm->validate(); return mm->validate();
} }
target_assignments program::get_target_assignments(const std::vector<target>& targets,
assignment_options options)
{
const auto m = options.metric;
target_assignments p;
const auto* mod = get_main_module();
for(auto it : iterator_for(*mod))
{
auto t = std::max_element(
targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
return lhs.is_supported(it, m) < rhs.is_supported(it, m);
});
p.add_assignment(it, t->name());
}
return p;
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
...@@ -721,11 +740,13 @@ void program::perf_report(std::ostream& os, ...@@ -721,11 +740,13 @@ void program::perf_report(std::ostream& os,
double overhead_percent = overhead_time * 100.0 / total_time; double overhead_percent = overhead_time * 100.0 / total_time;
double total_instruction_time = 0.0; double total_instruction_time = 0.0;
std::unordered_map<std::string, double> op_times; std::unordered_map<std::string, double> op_times;
std::unordered_map<std::string, std::size_t> op_n;
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
{ {
double avg = common_average(p.second); double avg = common_average(p.second);
op_times[perf_group(p.first->get_operator())] += avg; op_times[perf_group(p.first->get_operator())] += avg;
total_instruction_time += avg; total_instruction_time += avg;
op_n[perf_group(p.first->get_operator())]++;
} }
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
...@@ -746,18 +767,19 @@ void program::perf_report(std::ostream& os, ...@@ -746,18 +767,19 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Summary:" << std::endl; os << "Summary:" << std::endl;
std::vector<std::pair<double, std::string>> op_times_sorted; std::vector<std::tuple<double, std::size_t, std::string>> op_times_sorted;
std::transform(op_times.begin(), std::transform(
op_times.end(), op_times.begin(), op_times.end(), std::back_inserter(op_times_sorted), [&](auto p) {
std::back_inserter(op_times_sorted), auto&& name = p.first;
[](auto p) { return std::make_pair(p.second, p.first); }); return std::make_tuple(p.second, op_n.at(name), name);
});
std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{}); std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{});
for(auto&& p : op_times_sorted) for(auto&& [avg, nn, name] : op_times_sorted)
{ {
auto&& name = p.second;
double avg = p.first;
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << name << ": " << avg << "ms, " << percent << "%" << std::endl; double per_ins = avg / nn;
os << name << ": " << avg << "ms / " << nn << " = " << per_ins << "ms, " << percent << "%"
<< std::endl;
} }
os << std::endl; os << std::endl;
......
...@@ -272,7 +272,7 @@ struct find_concat_transpose ...@@ -272,7 +272,7 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target_assignments.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void target_assignments::add_assignment(instruction_ref ins, const std::string& target)
{
assignments.emplace(ins, target);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG_SYM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; params += " --std=c++17";
params += " -fno-gpu-rdc"; params += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) if(is_hcc_compiler())
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "test.hpp"
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/target_assignments.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto diff = mm->add_instruction(migraphx::make_op("div"), x, y);
mm->add_instruction(migraphx::make_op("div"), diff, z);
return p;
}
TEST_CASE(is_supported)
{
auto p = create_program();
auto targets = migraphx::get_targets();
EXPECT(!targets.empty());
auto first_target = targets[0];
auto t = migraphx::make_target(first_target);
const auto assignments = p.get_target_assignments({t});
for(const auto& [ins, target] : assignments)
{
(void)ins;
EXPECT(target == first_target);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1534,15 +1534,46 @@ TEST_CASE(test_squeeze_wrong_axis) ...@@ -1534,15 +1534,46 @@ TEST_CASE(test_squeeze_wrong_axis)
TEST_CASE(test_unsqueeze) TEST_CASE(test_unsqueeze)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 2, 6}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_non_divisable)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_step_zero)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_step_at_end)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {3}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_mismatch_step_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(test_unsqueeze_negative_axis)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
...@@ -1568,21 +1599,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2) ...@@ -1568,21 +1599,28 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
TEST_CASE(test_unsqueeze_transpose) TEST_CASE(test_unsqueeze_transpose)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}}; migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}}; migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_transpose_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 6}, {24, 1, 4}};
migraphx::shape s2{migraphx::shape::float_type, {4, 4, 2, 3}, {24, 1, 12, 4}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_multibroadcast) TEST_CASE(test_unsqueeze_multibroadcast)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 0, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_slice) TEST_CASE(test_unsqueeze_slice)
{ {
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}}; migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 4, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
...@@ -1614,6 +1652,27 @@ TEST_CASE(test_unsqueeze_multiple_axes_2) ...@@ -1614,6 +1652,27 @@ TEST_CASE(test_unsqueeze_multiple_axes_2)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1);
} }
TEST_CASE(test_unsqueeze_multiple_axes_3)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_4)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 5}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 1, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {5, 4, 2}}}), s1);
}
TEST_CASE(test_unsqueeze_multiple_axes_step)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 4, 10}};
migraphx::shape s2{migraphx::shape::float_type, {3, 4, 2, 5, 1, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}, {"steps", {2}}}), s1);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
......
...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast) ...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_unsqueeze_concat)
{
migraphx::module m1;
{
auto l0 = m1.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l1 = m1.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = m1.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 3;
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
m1.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args);
}
// TODO: This could be simplified to a single transpose after concat
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice) TEST_CASE(transpose_slice)
{ {
migraphx::module m1; migraphx::module m1;
......
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
##################################################################################### #####################################################################################
import string, sys, re import string, sys, re
trivial = ['std::size_t', 'instruction_ref'] trivial = ['std::size_t', 'instruction_ref', 'support_metric']
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment