Unverified Commit 4a4f537e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into gpu-batch-gemm-bert

parents adc03be6 a27dd28c
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(module& p) const void eliminate_concat::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Look for the concat operator // Look for the concat operator
if(ins->name() != concat_opt.name()) if(ins->name() != concat_opt.name())
...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const ...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std::sort(sorted_allocations.begin(), std::sort(sorted_allocations.begin(),
sorted_allocations.end(), sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) { [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y); return std::distance(m.begin(), x) < std::distance(m.begin(), y);
}); });
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = sorted_allocations.front(); auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first); auto super = m.move_instruction(last, first);
// Replace each allocation with a load // Replace each allocation with a load
std::size_t offset = 0; std::size_t offset = 0;
for(auto alloc : allocations) for(auto alloc : allocations)
{ {
op::load op{alloc->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
p.replace_instruction(alloc, op, {super}); m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes(); offset += alloc->get_shape().bytes();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::make_op("identity"), args); m.replace_instruction(ins, migraphx::make_op("identity"), args);
} }
} }
} }
......
...@@ -69,9 +69,9 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -69,9 +69,9 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& p) const void eliminate_contiguous::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// return instruction should have inputs with standard shape // return instruction should have inputs with standard shape
if(ins->name() == "@return") if(ins->name() == "@return")
...@@ -96,8 +96,8 @@ void eliminate_contiguous::apply(module& p) const ...@@ -96,8 +96,8 @@ void eliminate_contiguous::apply(module& p) const
auto c = op::contiguous{}; auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data()); auto l = m.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l); m.replace_instruction(arg, l);
} }
} }
} }
......
...@@ -8,21 +8,21 @@ ...@@ -8,21 +8,21 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(module& p) const void eliminate_identity::apply(module& m) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Skip the first instruction, since we always process the previous // Skip the first instruction, since we always process the previous
// instruction // instruction
if(ins == p.begin()) if(ins == m.begin())
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
if(i->name() == "identity") if(i->name() == "identity")
{ {
p.replace_instruction(i, i->inputs().front()); m.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end()); m.move_instruction(i, m.end());
} }
if(ins == last) if(ins == last)
{ {
...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const ...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const instruction_ref& identity_input = ins->inputs().front(); const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1) if(identity_input->outputs().size() == 1)
{ {
p.move_instruction(identity_input, i); m.move_instruction(identity_input, i);
// since this is the last instruction, removing it only // since this is the last instruction, removing it only
// requires changing "last" and calling remove below // requires changing "last" and calling remove below
last = std::prev(last); last = std::prev(last);
...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const ...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break; break;
} }
} }
p.remove_instructions(std::next(last), p.end()); m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -13,7 +13,7 @@ struct adjust_allocation ...@@ -13,7 +13,7 @@ struct adjust_allocation
{ {
allocation_model model; allocation_model model;
std::string name() const { return "adjust_allocation"; } std::string name() const { return "adjust_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct stream_race ...@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m); std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -13,7 +13,7 @@ struct module; ...@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous struct auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -33,7 +33,7 @@ struct check_context ...@@ -33,7 +33,7 @@ struct check_context
}; };
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); } void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -19,7 +19,7 @@ struct eliminate_allocation ...@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{}; std::string allocation_op{};
std::size_t alignment = 32; std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression struct eliminate_common_subexpression
{ {
std::string name() const { return "eliminate_common_subexpression"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct eliminate_concat ...@@ -18,7 +18,7 @@ struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -17,7 +17,7 @@ struct eliminate_contiguous ...@@ -17,7 +17,7 @@ struct eliminate_contiguous
{ {
std::string op_name; std::string op_name;
std::string name() const { return "eliminate_contiguous"; } std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct module; ...@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity struct eliminate_identity
{ {
std::string name() const { return "eliminate_identity"; } std::string name() const { return "eliminate_identity"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L #if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1 #define MIGRAPHX_HAS_FILESYSTEM 1
#else #else
......
...@@ -9,7 +9,19 @@ namespace migraphx { ...@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name); operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v); operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -156,6 +156,19 @@ struct id_matcher ...@@ -156,6 +156,19 @@ struct id_matcher
} }
}; };
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
struct basic_matcher struct basic_matcher
...@@ -167,8 +180,8 @@ struct basic_matcher ...@@ -167,8 +180,8 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
{ {
...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base ...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result struct matcher_result
{ {
std::unordered_map<std::string, instruction_ref> instructions; struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result; instruction_ref result;
}; };
...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) ...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
result.result = ins; result.result = ins;
result.instructions = ctx.instructions; result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
} }
else else
{ {
...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms) ...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
}); });
} }
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
......
...@@ -17,7 +17,7 @@ struct memory_coloring ...@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L #if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1 #define MIGRAPHX_HAS_OPTIONAL 1
#else #else
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct propagate_constant struct propagate_constant
{ {
std::string name() const { return "propagate_constant"; } std::string name() const { return "propagate_constant"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm struct rewrite_batchnorm
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling struct rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment