Unverified Commit 79e15ca9 authored by varunsh's avatar varunsh Committed by GitHub
Browse files

Update is_supported (#1334)

* Update is_supported
* Return object from is_supported
* Return by reference in interator
parent b691abdd
...@@ -90,7 +90,6 @@ add_library(migraphx ...@@ -90,7 +90,6 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -21,16 +21,24 @@ ...@@ -21,16 +21,24 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#include <migraphx/target_assignments.hpp> #include <unordered_set>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void target_assignments::add_assignment(instruction_ref ins, const std::string& target) struct supported_segment
{ {
assignments.emplace(ins, target); std::unordered_set<instruction_ref> instructions;
} float metric;
};
using supported_segments = std::vector<supported_segment>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp> #include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -64,12 +66,12 @@ struct target ...@@ -64,12 +66,12 @@ struct target
*/ */
context get_context() const; context get_context() const;
/** /**
* @brief Check how well an instruction is supported on a target with the given metric * @brief Get the ranges of instructions that are supported on a target
* @param ins Instruction to check if it's supported * @param module Module to check for supported instructions
* @param metric Used to define how the return value should be interpreted * @param metric Used to define how the quality of the support should be measured
* @return The value based on the chosen metric. Negative numbers mean unsupported * @return the supported segments of the graph
*/ */
float is_supported(T&, instruction_ref ins, support_metric m) const; supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg) ...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
} }
template <class T> template <class T>
float target_is_supported(T&, instruction_ref, support_metric) supported_segments target_find_supported(T&, const_module_ref, support_metric)
{ {
return 0; return {};
} }
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
...@@ -132,7 +134,7 @@ struct target ...@@ -132,7 +134,7 @@ struct target
// //
context get_context() const; context get_context() const;
// (optional) // (optional)
float is_supported(instruction_ref ins, support_metric m) const; supported_segments find_supported(const_module_ref mod, support_metric m) const;
// (optional) // (optional)
argument copy_to(const argument& input) const; argument copy_to(const argument& input) const;
// (optional) // (optional)
...@@ -224,10 +226,10 @@ struct target ...@@ -224,10 +226,10 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
float is_supported(instruction_ref ins, support_metric m) const supported_segments find_supported(const_module_ref mod, support_metric m) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m); return (*this).private_detail_te_get_handle().find_supported(mod, m);
} }
argument copy_to(const argument& input) const argument copy_to(const argument& input) const
...@@ -261,33 +263,33 @@ struct target ...@@ -261,33 +263,33 @@ struct target
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx, virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0; const compile_options& options) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0; virtual supported_segments find_supported(const_module_ref mod, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T> template <class T>
static auto private_detail_te_default_is_supported(char, static auto private_detail_te_default_find_supported(char,
T&& private_detail_te_self, T&& private_detail_te_self,
instruction_ref ins, const_module_ref mod,
support_metric m) support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m)) -> decltype(private_detail_te_self.find_supported(mod, m))
{ {
return private_detail_te_self.is_supported(ins, m); return private_detail_te_self.find_supported(mod, m);
} }
template <class T> template <class T>
static float private_detail_te_default_is_supported(float, static supported_segments private_detail_te_default_find_supported(float,
T&& private_detail_te_self, T&& private_detail_te_self,
instruction_ref ins, const_module_ref mod,
support_metric m) support_metric m)
{ {
return target_is_supported(private_detail_te_self, ins, m); return target_find_supported(private_detail_te_self, mod, m);
} }
template <class T> template <class T>
...@@ -372,10 +374,11 @@ struct target ...@@ -372,10 +374,11 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override supported_segments find_supported(const_module_ref mod, support_metric m) const override
{ {
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m); return private_detail_te_default_find_supported(
char(0), private_detail_te_value, mod, m);
} }
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
......
...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments struct target_assignments
{ {
void add_assignment(instruction_ref ins, const std::string& target); using iterator = std::unordered_map<instruction_ref, std::string>::const_iterator;
using value_type = std::pair<instruction_ref, std::string>;
auto begin() const { return assignments.cbegin(); } auto size() const { return assignments.size(); }
auto end() const { return assignments.cend(); } auto& at(instruction_ref ins) const { return assignments.at(ins); }
auto insert(iterator it, const std::pair<instruction_ref, std::string>& assignment)
{
return assignments.insert(it, assignment);
}
auto find(instruction_ref ins) const { return assignments.find(ins); }
auto begin() const { return assignments.begin(); }
auto end() const { return assignments.end(); }
private: private:
std::unordered_map<instruction_ref, std::string> assignments; std::unordered_map<instruction_ref, std::string> assignments;
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <migraphx/output_iterator.hpp> #include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
...@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta ...@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
target_assignments p; target_assignments p;
const auto* mod = get_main_module(); const auto* mod = get_main_module();
for(auto it : iterator_for(*mod)) std::vector<std::pair<target, supported_segments>> target_subgraphs;
target_subgraphs.reserve(targets.size());
std::transform(targets.begin(),
targets.end(),
std::back_inserter(target_subgraphs),
[&](const auto& t) { return std::make_pair(t, t.find_supported(mod, m)); });
for(const auto ins : iterator_for(*mod))
{ {
auto t = std::max_element( if(contains(p, ins))
targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) { {
return lhs.is_supported(it, m) < rhs.is_supported(it, m); continue;
}); }
p.add_assignment(it, t->name());
for(const auto& [target, subgraph] : target_subgraphs)
{
// can't pass a structured binding into lambda in C++17 so create a variable for it
const auto& t = target;
for(const auto& segment : subgraph)
{
const auto& instructions = segment.instructions;
if(not contains(instructions, ins))
{
continue;
}
std::transform(instructions.begin(),
instructions.end(),
std::inserter(p, p.end()),
[&](auto instr) { return std::make_pair(instr, t.name()); });
}
}
} }
return p; return p;
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/fpga/context.hpp> #include <migraphx/fpga/context.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,7 +42,7 @@ struct target ...@@ -41,7 +42,7 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
float is_supported(instruction_ref ins, support_metric m); supported_segments find_supported(const_module_ref mod, support_metric m) const;
argument copy_to(const argument& arg) const { return arg; } argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; } argument copy_from(const argument& arg) const { return arg; }
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/normalize_ops.hpp> #include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -62,12 +63,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -62,12 +63,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
argument target::allocate(const shape& s) const { return fill_argument(s, 0); } argument target::allocate(const shape& s) const { return fill_argument(s, 0); }
float is_supported(instruction_ref ins, support_metric m) supported_segments target::find_supported(const_module_ref mod, support_metric m) const
{ {
// for now, not using the ins and metric to return a value
(void)ins;
(void)m; (void)m;
return 1.0;
supported_segment instrs;
for(const auto ins : iterator_for(*mod))
{
instrs.instructions.insert(ins);
}
instrs.metric = 1; // arbitrary value
return {instrs};
} }
MIGRAPHX_REGISTER_TARGET(target); MIGRAPHX_REGISTER_TARGET(target);
......
...@@ -26,8 +26,9 @@ ...@@ -26,8 +26,9 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/fpga/target.hpp>
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include <migraphx/iterator_for.hpp>
migraphx::program create_program() migraphx::program create_program()
{ {
...@@ -37,8 +38,8 @@ migraphx::program create_program() ...@@ -37,8 +38,8 @@ migraphx::program create_program()
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto diff = mm->add_instruction(migraphx::make_op("div"), x, y); auto diff = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("div"), diff, z); mm->add_instruction(migraphx::make_op("add"), diff, z);
return p; return p;
} }
...@@ -47,14 +48,16 @@ TEST_CASE(is_supported) ...@@ -47,14 +48,16 @@ TEST_CASE(is_supported)
auto p = create_program(); auto p = create_program();
auto targets = migraphx::get_targets(); auto targets = migraphx::get_targets();
EXPECT(!targets.empty()); EXPECT(!targets.empty());
auto first_target = targets[0]; auto t = migraphx::make_target("fpga");
auto t = migraphx::make_target(first_target);
const auto assignments = p.get_target_assignments({t}); const auto assignments = p.get_target_assignments({t});
for(const auto& [ins, target] : assignments) const auto* mod = p.get_main_module();
EXPECT(mod->size() == assignments.size());
for(const auto ins : iterator_for(*mod))
{ {
(void)ins; const auto& target = assignments.at(ins);
EXPECT(target == first_target); EXPECT(target == "fpga");
} }
} }
......
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp> #include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -64,12 +66,12 @@ struct target ...@@ -64,12 +66,12 @@ struct target
*/ */
context get_context() const; context get_context() const;
/** /**
* @brief Check how well an instruction is supported on a target with the given metric * @brief Get the ranges of instructions that are supported on a target
* @param ins Instruction to check if it's supported * @param module Module to check for supported instructions
* @param metric Used to define how the return value should be interpreted * @param metric Used to define how the quality of the support should be measured
* @return The value based on the chosen metric. Negative numbers mean unsupported * @return the supported segments of the graph
*/ */
float is_supported(T&, instruction_ref ins, support_metric m) const; supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg) ...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
} }
template <class T> template <class T>
float target_is_supported(T&, instruction_ref, support_metric) supported_segments target_find_supported(T&, const_module_ref, support_metric)
{ {
return 0; return {};
} }
<% <%
...@@ -125,7 +127,7 @@ interface('target', ...@@ -125,7 +127,7 @@ interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'), virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
##################################################################################### #####################################################################################
import string, sys, re import string, sys, re
trivial = ['std::size_t', 'instruction_ref', 'support_metric'] trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref'
]
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment