Unverified Commit 4a4f537e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into gpu-batch-gemm-bert

parents adc03be6 a27dd28c
......@@ -13,9 +13,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(module& p) const
void eliminate_concat::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
......@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std::sort(sorted_allocations.begin(),
sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
});
// Move "super" allocation to the front
auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first);
auto super = m.move_instruction(last, first);
// Replace each allocation with a load
std::size_t offset = 0;
for(auto alloc : allocations)
{
op::load op{alloc->get_shape(), offset};
p.replace_instruction(alloc, op, {super});
m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::make_op("identity"), args);
m.replace_instruction(ins, migraphx::make_op("identity"), args);
}
}
}
......
......@@ -69,9 +69,9 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods);
}
void eliminate_contiguous::apply(module& p) const
void eliminate_contiguous::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
// return instruction should have inputs with standard shape
if(ins->name() == "@return")
......@@ -96,8 +96,8 @@ void eliminate_contiguous::apply(module& p) const
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
auto l = m.add_literal(r.get_shape(), r.data());
m.replace_instruction(arg, l);
}
}
}
......
......@@ -8,21 +8,21 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(module& p) const
void eliminate_identity::apply(module& m) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
if(ins == m.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
p.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end());
m.replace_instruction(i, i->inputs().front());
m.move_instruction(i, m.end());
}
if(ins == last)
{
......@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1)
{
p.move_instruction(identity_input, i);
m.move_instruction(identity_input, i);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
......@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break;
}
}
p.remove_instructions(std::next(last), p.end());
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -13,7 +13,7 @@ struct adjust_allocation
{
allocation_model model;
std::string name() const { return "adjust_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before;
};
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m);
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous
{
std::string name() const { return "auto_contiguous"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -33,7 +33,7 @@ struct check_context
};
std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); }
void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{};
std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression
{
std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std::string op_name;
std::string name() const { return "eliminate_contiguous"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct module;
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
......
......@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
......@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
......@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
......@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
......@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
......
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
......
......@@ -15,7 +15,7 @@ struct module;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment