Commit 30c49503 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 870a396b 09aaa63e
......@@ -243,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool dynamic() const;
/// Return true if this shape or any of the sub_shapes are dynamic
bool any_of_dynamic() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
using instruction_set = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
auto implicit_deps = m.calc_implicit_deps();
instruction_set live_set;
auto rp = reverse(m);
for(auto rins : iterator_for(rp)) // NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto ins = std::prev(rins.base());
// Add live variables
auto add_live_variables = [&](const auto& inputs) {
for(auto input : inputs)
{
auto i = instruction::get_output_alias(input);
// Skip if variable comes from parent
if(not m.has_instruction(i))
continue;
live_set.insert(i);
}
};
add_live_variables(ins->inputs());
add_live_variables(implicit_deps[ins]);
// Remove last usage
auto it = live_set.find(ins);
if(it != live_set.end())
{
live_set.erase(it);
f(ins, live_set);
}
}
}
// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instruction that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
instruction_set_map conflict_table;
liveness(m, [&](auto ins, auto live_set) {
// Skip variables that aren't allocations
if(ins->name() != allocation_op)
return;
// Skip zero allocations
if(ins->get_shape().bytes() == 0)
return;
conflict_table[ins];
for(auto i : live_set)
{
if(i == ins)
continue;
// Skip variables that aren't allocations
if(i->name() != allocation_op)
continue;
// Skip zero allocations
if(i->get_shape().bytes() == 0)
continue;
conflict_table[i].insert(ins);
conflict_table[ins].insert(i);
}
});
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
return pp.second.count(pp.first) == 0;
}));
return conflict_table;
}
// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
return std::max(x.first, y.first) < std::min(x.second, y.second);
}
struct allocation_segment
{
using segment = std::pair<std::size_t, std::size_t>;
std::unordered_map<instruction_ref, segment> ins2segment;
const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }
const segment* get_segment(instruction_ref ins) const
{
auto it = ins2segment.find(ins);
if(it == ins2segment.end())
return nullptr;
return &it->second;
}
// Remove segment for an instruction
void remove(instruction_ref ins)
{
auto it = ins2segment.find(ins);
if(it != ins2segment.end())
{
ins2segment.erase(it);
}
}
std::size_t max()
{
std::size_t n = 0;
for(auto&& pp : ins2segment)
{
auto seg = pp.second;
n = std::max(n, seg.second);
}
return n;
}
template <class Iterator>
static bool overlaps(Iterator first, Iterator last, const segment& s)
{
return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
}
static bool overlaps(const std::set<segment>& segments, const segment& s)
{
return overlaps(segments.begin(), segments.end(), s);
}
static auto find_gap(const std::set<segment>& segments, std::size_t n)
{
std::size_t max_end = 0;
return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
if(x.second < max_end)
return false;
max_end = x.second;
if(is_overlap(x, y))
return false;
assert(y.first >= x.second);
auto k = y.first - x.second;
return (k >= n);
});
}
static std::size_t max_type_size(const shape& s)
{
return std::accumulate(
s.sub_shapes().begin(),
s.sub_shapes().end(),
s.type_size(),
[](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
}
static std::size_t compute_alignment(instruction_ref ins)
{
auto alignment = max_type_size(ins->get_shape());
// A rough estimate for the total number of elements
auto n = ins->get_shape().bytes() / alignment;
// Check for vectorized alignment
if(n > 4)
{
auto d = n % 4;
if(d == 0)
alignment *= 4;
if(d == 2)
alignment *= 2;
}
return alignment;
}
static segment
next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
{
assert(ins->get_shape().bytes() > 0);
// Compute alignment
auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
assert(n > 0);
auto start = 0;
// Insert at end if it cant fit at the begining
if(segments.empty() or segments.begin()->first <= n)
{
auto it = find_gap(segments, n);
if(it == segments.end())
it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
return x.second < y.second;
});
if(it != segments.end())
start = it->second;
}
auto s = segment{start, start + n};
assert(not overlaps(segments, s));
segments.insert(s);
return s;
}
static std::unordered_map<instruction_ref, int>
create_allocation_index(const module& m, const instruction_set_map& conflict_table)
{
std::unordered_map<instruction_ref, int> result;
int i = 0;
for(auto ins : iterator_for(m))
{
if(not contains(conflict_table, ins))
continue;
result[ins] = i++;
}
return result;
}
// Build the allocation_color class from the conflict_table
static allocation_segment
build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
{
allocation_segment as{};
std::vector<instruction_ref> conflict_queue;
// Add all allocations to the conflict_queue
std::transform(conflict_table.begin(),
conflict_table.end(),
std::back_inserter(conflict_queue),
[](auto&& pp) { return pp.first; });
auto alloc_index = create_allocation_index(m, conflict_table);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(
conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for(auto parent : conflict_queue)
{
// Sort children by size
std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
conflict_table.at(parent).end());
std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
}));
assert(not contains(children, parent));
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
assert(as.get_segment(parent) == nullptr);
as.add_segment(parent, next_segment(segments, parent, alignment));
}
// Reduce the number of segments
for(std::size_t n = 0; n < 3; n++)
{
for(auto parent : conflict_queue)
{
auto children = conflict_table.at(parent);
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
// Get the segment for the parent
const auto* parent_segment = as.get_segment(parent);
assert(parent_segment != nullptr);
auto s = next_segment(segments, parent, alignment);
if(s != *parent_segment and s.second <= as.max())
{
as.add_segment(parent, s);
}
}
}
return as;
}
};
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
std::size_t alignment = 1;
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
}
return alignment;
}
void memory_coloring::apply(module& m) const
{
const std::size_t alignment = find_max_alignment(m, allocation_op);
auto conflict_table = build_conflict_table(m, allocation_op);
auto as = allocation_segment::build(m, conflict_table, alignment);
// All allocations should have a segment
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
return as.get_segment(pp.first);
}));
// Adjacent allocations should not have overlapping segments
assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
auto* x = as.get_segment(pp.first);
return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
auto* y = as.get_segment(ins);
assert(x and y);
return is_overlap(*x, *y);
});
}));
// Print out segments
if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
{
for(auto&& pp : conflict_table)
{
std::cout << "------- conflict -------" << std::endl;
auto s1 = as.ins2segment.at(pp.first);
std::cout << s1.first << ", " << s1.second << ": ";
m.debug_print(pp.first);
for(auto ins : pp.second)
{
auto s2 = as.ins2segment.at(ins);
std::cout << s2.first << ", " << s2.second << ": ";
m.debug_print(ins);
}
}
}
// Total memory
std::size_t n = as.max() * alignment;
// Replace allocations
auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
for(auto&& [ins, seg] : as.ins2segment)
{
assert(ins->name() == allocation_op);
auto s = ins->get_shape();
std::size_t offset = seg.first * alignment;
assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem);
}
// Replace zero allocation
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
}
// Remove scratch parameter if its not used
if(mem->outputs().empty())
{
m.remove_instruction(mem);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}});
impl->nparams++;
}
else if(ins->name() == "@outline")
{
......@@ -822,7 +823,8 @@ static void print_make_op(std::ostream& os, const operation& op)
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
os << "migraphx.shape(type=" << to_json_string(s.type_string())
<< ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
......
......@@ -30,13 +30,16 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
/**
* Parameters:
* vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
* normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
......@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return result;
}
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
bool tuned = false;
......@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto val = op.to_value();
if(attrs.contains("normalize_padding"))
{
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size();
auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start))
......
......@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph);
std::vector<instruction_ref>
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
......@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
......@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
......@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version;
}
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
......@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instuction
mod->add_return(output_ins);
if(not inlining)
{
// add the return instuction
mod->add_return(output_ins);
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
return output_ins;
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
......@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3)
{
// TODO: support dynamic C input
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
if(not float_equal(beta, 0.0f))
{
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes");
}
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
auto c_arg = args[2];
if(dot_ins->get_shape().dynamic())
{
c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
}
else
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
}
auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_c->get_shape().type() != dot_type)
if(not float_equal(beta, 1.0f))
{
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_c);
auto beta_literal = info.add_literal(beta);
c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
}
return info.add_instruction(make_op("add"), ret, beta_c);
return info.add_instruction(make_op("add"), dot_ins, c_arg);
}
}
return ret;
return dot_ins;
}
};
......
......@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!");
}
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if(args.front()->can_eval())
{
auto cond_arg = args.front()->eval();
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
return parser.parse_graph(mod, then_graph, true);
}
// else branch
else
{
return parser.parse_graph(mod, else_graph, true);
}
}
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
......@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph
parser.parse_graph(then_mdl, then_graph);
(void)parser.parse_graph(then_mdl, then_graph);
// parse_the else sub_graph
parser.parse_graph(else_mdl, else_graph);
(void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
......
......@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
parser.parse_graph(sub_mod, sub_graph);
(void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
......@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5)
{
migraphx::argument step_arg = args.back()->eval();
......@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
// If axes arg is not given, the default is all of them.
if(op.axes.empty())
{
std::vector<int64_t> axes(args[0]->get_shape().lens().size());
std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes;
}
......@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size()))
{
if(steps[i] >= 0)
......@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
......
......@@ -21,20 +21,70 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/context.hpp>
#include <migraphx/ref/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
TEST_CASE(context)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_trilu : op_parser<parse_trilu>
{
migraphx::context ctx = migraphx::ref::context{};
migraphx::value v = ctx.to_value();
EXPECT(v.empty());
std::vector<op_desc> operators() const { return {{"Trilu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
assert(input_shape.ndim() >= 2);
auto input_lens = input_shape.lens();
size_t num_rows = *(input_lens.rbegin() + 1);
size_t num_cols = input_lens.back();
int k = 0;
bool upper = true;
if(args.size() > 1)
{
auto arg_k = args[1]->eval();
check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
}
shape::type_t output_type = args[0]->get_shape().type();
// when creating the mask, if upper == 1,
// the inner triangle will have values set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
k++;
}
auto mask = info.add_literal(
migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
migraphx::context cpu_ctx = migraphx::ref::context{};
cpu_ctx.from_value(v);
}
return info.add_broadcastable_binary_op("mul", mask, args[0]);
}
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
// TODO: broadcasting for dynamic shapes is only implemented
// for binary ops at time of writing, not ternary ops.
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
if(args[1]->get_shape().lens() != lens)
else if(std::none_of(
args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
// If shapes are static and any are broadcasted, insert multibroadcast ops
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
else
MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
}
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run()
{
// calc implicit depdendencies
mod_implicit_deps = p_mod->calc_implicit_deps();
MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_module());
build();
if(num_of_lives != 0)
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
// rewrite happens after all modules are processed
rewrite();
if(enable_verify)
verify();
}
}
bool memory_coloring_impl::allocate(interval_ptr interval)
{
shape s = interval->result;
std::size_t size = s.bytes();
if(size == 0)
return false;
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
std::unordered_map<long long, live_range*> offset2_live;
offset2_live.clear();
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
long long offset = range->offset;
if(offset != invalid_offset)
{
conflict_queue.push(range);
if(offset2_live.find(offset) == offset2_live.end())
{
offset2_live[offset] = range;
}
else
{
live_range* prev = offset2_live[offset];
assert(prev->offset == offset);
if(prev->size < range->size)
offset2_live[offset] = range;
}
}
}
}
std::size_t offset = 0;
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
if(offset > iter_offset)
{
offset = std::max(offset, iter_offset + range->size);
}
else if(offset2_live[iter_offset] == range)
{
if((iter_offset > offset) && (iter_offset - offset) >= size)
{
break;
}
offset = iter_offset + range->size;
}
// alignment
if((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset = (offset + 3) / 4 * 4;
segment.offset = offset;
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
std::size_t num_of_instrs = p_mod->size();
if(num_of_instrs == 0)
return;
auto cur_points = num_of_instrs * 2;
instruction_ref iter = p_mod->end();
instruction_ref begin = p_mod->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{
iter = std::prev(iter);
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
auto inputs = iter->inputs();
if(contains(mod_implicit_deps, iter))
{
const auto& impl_deps = mod_implicit_deps.at(iter);
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
}
for(auto&& arg : inputs)
{
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
if(def_interval != nullptr)
{
def_interval->is_live_on_entry = true;
}
continue;
}
const instruction* p_arg = &(*instruction::get_output_alias(arg));
if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
earliest_end_point = cur_points;
if(latest_end_point == -1)
latest_end_point = cur_points;
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
} while(iter != begin);
}
void memory_coloring_impl::rewrite()
{
std::vector<std::size_t> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_mod))
{
const instruction* p_iter = &(*ins);
if(instr2_live.find(p_iter) != instr2_live.end())
{
interval_ptr interval = instr2_live[p_iter];
if(interval->get_begin() == invalid_offset)
continue;
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
if(interval->get_offset() != invalid_offset)
{
offset = interval->get_offset();
}
else
{
assert(interval->result.bytes() == 0);
}
if(is_allocate(ins))
{
p_mod->replace_instruction(
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_module());
}
void memory_coloring_impl::verify()
{
if(num_of_lives > 0)
{
for(int i = 0; i < num_of_lives; ++i)
{
const live_interval& interval = live_intervals[i];
const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
if(segment.offset == invalid_offset)
{
continue;
}
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
}
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
const std::set<int>& table = conflict_table[i];
for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
}
std::cout << std::endl;
}
}
// map liveness tracking point to instruction enum.
static int get_ins_enum(int x)
{
if(x > 0)
{
return (x / 2) - 1;
}
else
return invalid_offset;
}
void live_range::dump()
{
std::cout << " segment:" << vn;
std::cout << " [" << get_ins_enum(begin) << ", " << get_ins_enum(end) << "]";
if(offset != invalid_offset)
{
std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]";
}
std::cout << std::endl;
}
void live_interval::dump()
{
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
std::cout << " def:";
std::cout << " " << get_ins_enum(def_point);
if(is_literal)
std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();
struct live_range
{
std::size_t begin; // begin point in the instruction stream.
std::size_t end; // end point in the instruction stream.
std::size_t offset; // offset to base pointer of allocated memory trunk.
std::size_t vn; // value number that identifies this live_range.
std::size_t size; // size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
};
struct live_interval
{
live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0})
{
}
void add_use(std::size_t use) { use_points.push_front(use); }
std::size_t get_begin() const { return segment.begin; }
std::size_t get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
live_range segment;
std::size_t id = invalid_offset;
std::list<std::size_t> use_points{};
std::size_t def_point = invalid_offset;
shape result{};
bool is_literal = false;
bool is_live_on_entry = false;
};
using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
}
bool allocate(interval_ptr);
void add_conflicts(const std::set<int>& live_set, int val)
{
for(const auto& iter : live_set)
{
conflict_table[iter].insert(val);
conflict_table[val].insert(iter);
}
}
void build();
void run();
void rewrite();
private:
static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }
static bool is_output_param(const instruction_ref ins)
{
if(not is_param(ins))
return false;
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
return contains(param_name, "#output_");
}
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
{
return ins->name() == "check_context";
}
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&);
void dump_module();
void dump_intervals();
#endif
struct ordering
{
bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
{
auto len1 = i1->get_end() - i1->get_begin();
auto len2 = i2->get_end() - i2->get_begin();
if(len1 != len2)
{
return (len1 < len2);
}
else if(i1->result.bytes() != i2->result.bytes())
{
return (i1->result.bytes() < i2->result.bytes());
}
else
{
return i1->id > i2->id;
}
}
bool operator()(const live_range* i1, const live_range* i2) const
{
return (i1->offset > i2->offset);
}
};
module* p_mod;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
std::vector<live_interval> live_intervals = {};
// Map live range value number to live range.
std::unordered_map<int, live_range*> live_ranges = {};
// Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table = {};
// Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};
int num_of_lives = 0;
int max_value_number = -1;
std::size_t required_bytes = 0;
// The earliest program point where an live interval ends.
int earliest_end_point = -1;
// The latest program point where an live interval ends.
int latest_end_point = -1;
// Whether to unify literals into coloring.
bool unify_literals = false;
std::string allocation_op{};
bool enable_verify;
ins_dep_map mod_implicit_deps;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,18 +21,27 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include "memory_coloring_impl.hpp"
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& m) const
void optimize_module::apply(module_pass_manager& mpm) const
{
if(not enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
for(int i = 0; i < 2; i++)
{
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{});
mpm.run_pass(dead_code_elimination{});
}
}
......
......@@ -39,6 +39,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace)
{
......@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override
{
assert(mod);
timer ts{};
using seconds = std::chrono::duration<double>;
trace("Module: ", mod->name(), ", Pass: ", p.name());
const double t1 = ts.record<seconds>();
assert(mod->validate() == mod->end());
p.apply(*this);
if(enabled(MIGRAPHX_TIME_PASSES{}))
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto ms = time<milliseconds>([&] { p.apply(*this); });
std::cout << p.name() << ": " << ms << "ms\n";
}
else
{
p.apply(*this);
}
trace(*mod);
validate_pass(*mod, p, *t);
const double t2 = ts.record<seconds>();
trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
}
};
......
......@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{
......@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
"} for parameter: " + param_name +
" should be: " + to_string(ins->get_shape()));
}
return param;
}));
......@@ -380,7 +379,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
if(not ins->get_shape().dynamic())
if(not ins->get_shape().any_of_dynamic())
{
assert(results.at(ins).get_shape() == ins->get_shape());
}
......
......@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return false;
}
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const
{
......@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m))
{
if(is_const(i) and i != last)
const bool is_const = is_const_ins(i);
if(is_const and i != last)
continue;
std::copy_if(
i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
if(i == last and is_const)
{
const_instrs.insert(i);
}
else
{
std::copy_if(i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) {
return is_const_ins(ins) and ins->name() != "@literal";
});
}
}
// Compute literals in parallel
......
......@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("is_compiled", &migraphx::program::is_compiled)
.def(
"compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
[](migraphx::program& p,
const migraphx::target& t,
bool offload_copy,
bool fast_math,
bool exhaustive_tune) {
migraphx::compile_options options;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
options.exhaustive_tune = exhaustive_tune;
p.compile(t, options);
},
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
py::arg("offload_copy") = true,
py::arg("fast_math") = true,
py::arg("exhaustive_tune") = false)
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
......
......@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static std::unordered_map<std::string, operation> m; // NOLINT
return m;
}
void register_op_init() { (void)op_map(); }
void register_op(const operation& op) { op_map()[op.name()] = op; }
void unregister_op(const std::string& op_name)
{
assert(op_map().count(op_name));
op_map().erase(op_name);
}
operation load_op(const std::string& name)
{
return at(op_map(), name, "Operator not found: " + name);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment