Commit 2fc6b715 authored by Paul's avatar Paul
Browse files

Merge

parents 5967d68d 118e05c7
...@@ -33,15 +33,36 @@ ...@@ -33,15 +33,36 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void unregister_op(const std::string& op_name);
namespace detail {
struct op_handler
{
operation op;
std::string name;
op_handler(const operation& op_r) : op(op_r), name(op.name()){};
~op_handler() { unregister_op(name); }
};
} // namespace detail
void register_op_init();
void register_op(const operation& op); void register_op(const operation& op);
operation load_op(const std::string& name); operation load_op(const std::string& name);
bool has_op(const std::string& name); bool has_op(const std::string& name);
std::vector<std::string> get_operators(); std::vector<std::string> get_operators();
template <class T> template <class T>
void register_op() void register_op()
{ {
register_op(T{}); register_op_init(); // instantiate static op_map;
static auto op_h = detail::op_handler(T{});
register_op(op_h.op);
} }
struct register_op_action struct register_op_action
......
...@@ -33,14 +33,28 @@ ...@@ -33,14 +33,28 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void register_target_init();
void register_target(const target& t); void register_target(const target& t);
void unregister_target(const std::string& name);
target make_target(const std::string& name); target make_target(const std::string& name);
std::vector<std::string> get_targets(); std::vector<std::string> get_targets();
namespace detail {
struct target_handler
{
target t;
std::string target_name;
target_handler(const target& t_r) : t(t_r), target_name(t.name()) {}
~target_handler() { unregister_target(target_name); }
};
} // namespace detail
template <class T> template <class T>
void register_target() void register_target()
{ {
register_target(T{}); register_target_init();
static auto t_h = detail::target_handler(T{});
register_target(t_h.t);
} }
struct register_target_action struct register_target_action
......
...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x) ...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x)
-> decltype(x.insert(*x.begin()), std::declval<typename T::mapped_type>(), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
......
...@@ -29,10 +29,12 @@ ...@@ -29,10 +29,12 @@
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <set>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -87,12 +89,12 @@ struct shape ...@@ -87,12 +89,12 @@ struct shape
{ {
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::set<std::size_t> optimals{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt")); return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals"));
} }
bool is_fixed() const; bool is_fixed() const;
...@@ -132,11 +134,12 @@ struct shape ...@@ -132,11 +134,12 @@ struct shape
shape(type_t t, std::vector<dynamic_dimension> dims); shape(type_t t, std::vector<dynamic_dimension> dims);
// Construct a dynamic shape from three sets of lengths (of the same rank) // Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape(type_t t, shape(type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts); std::vector<std::set<std::size_t>> optimals_list);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
...@@ -186,21 +189,21 @@ struct shape ...@@ -186,21 +189,21 @@ struct shape
/*! /*!
* Minimum lengths for dynamic shape. * Minimum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> min_lens() const; std::vector<std::size_t> min_lens() const;
/*! /*!
* Maximum lengths for dynamic shape. * Maximum lengths for dynamic shape.
* lens() for fixed shape. * lens() for static shape.
*/ */
std::vector<std::size_t> max_lens() const; std::vector<std::size_t> max_lens() const;
/*! /*!
* Optimum lengths for dynamic shape. * Optimum lengths for dynamic shape.
* lens() for fixed shape. * Empty for static shape.
*/ */
std::vector<std::size_t> opt_lens() const; std::vector<std::set<std::size_t>> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
...@@ -253,9 +256,12 @@ struct shape ...@@ -253,9 +256,12 @@ struct shape
shape with_type(type_t t) const; shape with_type(type_t t) const;
// convert the shape to an equivalent dynamic shape // convert the shape to an equivalent dynamic shape with empty optimals
shape to_dynamic() const; shape to_dynamic() const;
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,20 +21,28 @@ ...@@ -21,20 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/serialize.hpp> #ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <migraphx/context.hpp> #define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <migraphx/ref/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
TEST_CASE(context) #include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct split_single_dyn_dim
{ {
migraphx::context ctx = migraphx::ref::context{}; std::string name() const { return "split_single_dyn_dim"; }
migraphx::value v = ctx.to_value(); void apply(module_pass_manager&) const;
EXPECT(v.empty()); };
migraphx::context cpu_ctx = migraphx::ref::context{}; } // namespace MIGRAPHX_INLINE_NS
cpu_ctx.from_value(v); } // namespace migraphx
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } #endif
...@@ -166,6 +166,7 @@ void module::assign(const module& m) ...@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(), copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}}); {builtin::param{name, order}, std::move(s), {}});
impl->nparams++;
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
...@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const ...@@ -594,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
} }
} }
std::vector<instruction_ref> module::get_returns() const
{
auto last = std::prev(this->end());
if(last->name() == "@return")
return last->inputs();
return {last};
}
instruction_ref module::validate() const instruction_ref module::validate() const
{ {
return std::find_if( return std::find_if(
......
...@@ -172,6 +172,22 @@ struct vector_stream ...@@ -172,6 +172,22 @@ struct vector_stream
} }
}; };
struct writer_stream
{
std::function<void(const char*, std::size_t)> writer;
writer_stream& write(const char* b, std::size_t n)
{
writer(b, n);
return *this;
}
};
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
{
writer_stream ws{std::move(writer)};
msgpack::pack(ws, v);
}
std::vector<char> to_msgpack(const value& v) std::vector<char> to_msgpack(const value& v)
{ {
vector_stream vs; vector_stream vs;
......
...@@ -94,7 +94,7 @@ struct onnx_parser ...@@ -94,7 +94,7 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
......
...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto dim_val = options.default_dim_value; auto dim_val = options.default_dim_value;
if(dim_val != 0) if(dim_val != 0)
{ {
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0}) if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
{ {
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value" MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"); "set to non-default value");
} }
else else
{ {
parser.default_dyn_dim_value = {dim_val, dim_val, 0}; parser.default_dyn_dim_value = {dim_val, dim_val};
} }
} }
else else
......
...@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value; return default_dyn_dim_value;
} }
std::size_t tmp = d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0}; return {tmp, tmp};
} }
else else
{ {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto n_dim = input_lens.size(); auto n_dim = input_lens.size();
instruction_ref y_scale; instruction_ref y_scale = args[1];
if(args[1]->get_shape().elements() != 1) if(args[1]->get_shape().elements() != 1)
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else
{ auto common_args = add_common_args(*info.mod, {args[0], y_scale});
y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); common_args.push_back(y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale); return info.add_instruction(make_op("quantizelinear"), common_args);
} }
}; };
......
...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
return info.add_instruction(make_op("reshape", {{"dims", dims}}), auto cont = info.add_instruction(make_op("contiguous"), args[0]);
info.make_contiguous(args[0])); return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run()
{
// calc implicit depdendencies
mod_implicit_deps = p_mod->calc_implicit_deps();
MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_module());
build();
if(num_of_lives != 0)
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
// rewrite happens after all modules are processed
rewrite();
if(enable_verify)
verify();
}
}
bool memory_coloring_impl::allocate(interval_ptr interval)
{
shape s = interval->result;
std::size_t size = s.bytes();
if(size == 0)
return false;
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
std::unordered_map<long long, live_range*> offset2_live;
offset2_live.clear();
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
long long offset = range->offset;
if(offset != invalid_offset)
{
conflict_queue.push(range);
if(offset2_live.find(offset) == offset2_live.end())
{
offset2_live[offset] = range;
}
else
{
live_range* prev = offset2_live[offset];
assert(prev->offset == offset);
if(prev->size < range->size)
offset2_live[offset] = range;
}
}
}
}
std::size_t offset = 0;
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
if(offset > iter_offset)
{
offset = std::max(offset, iter_offset + range->size);
}
else if(offset2_live[iter_offset] == range)
{
if((iter_offset > offset) && (iter_offset - offset) >= size)
{
break;
}
offset = iter_offset + range->size;
}
// alignment
if((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset = (offset + 3) / 4 * 4;
segment.offset = offset;
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
std::size_t num_of_instrs = p_mod->size();
if(num_of_instrs == 0)
return;
auto cur_points = num_of_instrs * 2;
instruction_ref iter = p_mod->end();
instruction_ref begin = p_mod->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{
iter = std::prev(iter);
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
auto inputs = iter->inputs();
if(contains(mod_implicit_deps, iter))
{
const auto& impl_deps = mod_implicit_deps.at(iter);
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
}
for(auto&& arg : inputs)
{
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
if(def_interval != nullptr)
{
def_interval->is_live_on_entry = true;
}
continue;
}
const instruction* p_arg = &(*instruction::get_output_alias(arg));
if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
earliest_end_point = cur_points;
if(latest_end_point == -1)
latest_end_point = cur_points;
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
} while(iter != begin);
}
void memory_coloring_impl::rewrite()
{
std::vector<std::size_t> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_mod))
{
const instruction* p_iter = &(*ins);
if(instr2_live.find(p_iter) != instr2_live.end())
{
interval_ptr interval = instr2_live[p_iter];
if(interval->get_begin() == invalid_offset)
continue;
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
if(interval->get_offset() != invalid_offset)
{
offset = interval->get_offset();
}
else
{
assert(interval->result.bytes() == 0);
}
if(is_allocate(ins))
{
p_mod->replace_instruction(
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_module());
}
void memory_coloring_impl::verify()
{
if(num_of_lives > 0)
{
for(int i = 0; i < num_of_lives; ++i)
{
const live_interval& interval = live_intervals[i];
const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
if(segment.offset == invalid_offset)
{
continue;
}
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
}
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
const std::set<int>& table = conflict_table[i];
for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
}
std::cout << std::endl;
}
}
// map liveness tracking point to instruction enum.
static int get_ins_enum(int x)
{
if(x > 0)
{
return (x / 2) - 1;
}
else
return invalid_offset;
}
void live_range::dump()
{
std::cout << " segment:" << vn;
std::cout << " [" << get_ins_enum(begin) << ", " << get_ins_enum(end) << "]";
if(offset != invalid_offset)
{
std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]";
}
std::cout << std::endl;
}
void live_interval::dump()
{
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
std::cout << " def:";
std::cout << " " << get_ins_enum(def_point);
if(is_literal)
std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();
struct live_range
{
std::size_t begin; // begin point in the instruction stream.
std::size_t end; // end point in the instruction stream.
std::size_t offset; // offset to base pointer of allocated memory trunk.
std::size_t vn; // value number that identifies this live_range.
std::size_t size; // size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
};
struct live_interval
{
live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0})
{
}
void add_use(std::size_t use) { use_points.push_front(use); }
std::size_t get_begin() const { return segment.begin; }
std::size_t get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
live_range segment;
std::size_t id = invalid_offset;
std::list<std::size_t> use_points{};
std::size_t def_point = invalid_offset;
shape result{};
bool is_literal = false;
bool is_live_on_entry = false;
};
using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
}
bool allocate(interval_ptr);
void add_conflicts(const std::set<int>& live_set, int val)
{
for(const auto& iter : live_set)
{
conflict_table[iter].insert(val);
conflict_table[val].insert(iter);
}
}
void build();
void run();
void rewrite();
private:
static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }
static bool is_output_param(const instruction_ref ins)
{
if(not is_param(ins))
return false;
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
return contains(param_name, "#output_");
}
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
{
return ins->name() == "check_context";
}
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&);
void dump_module();
void dump_intervals();
#endif
struct ordering
{
bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
{
auto len1 = i1->get_end() - i1->get_begin();
auto len2 = i2->get_end() - i2->get_begin();
if(len1 != len2)
{
return (len1 < len2);
}
else if(i1->result.bytes() != i2->result.bytes())
{
return (i1->result.bytes() < i2->result.bytes());
}
else
{
return i1->id > i2->id;
}
}
bool operator()(const live_range* i1, const live_range* i2) const
{
return (i1->offset > i2->offset);
}
};
module* p_mod;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
std::vector<live_interval> live_intervals = {};
// Map live range value number to live range.
std::unordered_map<int, live_range*> live_ranges = {};
// Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table = {};
// Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};
int num_of_lives = 0;
int max_value_number = -1;
std::size_t required_bytes = 0;
// The earliest program point where an live interval ends.
int earliest_end_point = -1;
// The latest program point where an live interval ends.
int latest_end_point = -1;
// Whether to unify literals into coloring.
bool unify_literals = false;
std::string allocation_op{};
bool enable_verify;
ins_dep_map mod_implicit_deps;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager ...@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager
assert(mod); assert(mod);
return *mod; return *mod;
} }
virtual module* create_module(const std::string& name) override virtual module* create_module(const std::string& name) override
{ {
assert(prog); assert(prog);
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* get_common_parent() override { return common_parent; } virtual module* get_common_parent() override { return common_parent; }
virtual module* get_root_module() override
{
assert(prog);
return prog->get_main_module();
}
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
assert(mod); assert(mod);
......
...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os) ...@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return [&](const char* x) { os << x; }; return [&](const char* x) { os << x; };
} }
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out) template <class F>
int exec(const std::string& cmd, const char* type, F f)
{ {
int ec = 0; int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << cmd << std::endl;
auto closer = [&](FILE* stream) { auto closer = [&](FILE* stream) {
auto status = pclose(stream); auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
}; };
{ {
// TODO: Use execve instead of popen // TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
if(not pipe) if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd); MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer; f(pipe.get());
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
} }
return ec; return ec;
} }
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
return exec(cmd, "r", [&](FILE* f) {
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), f) != nullptr)
std_out(buffer.data());
});
}
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
return exec(cmd, "w", [&](FILE* f) {
std_in([&](const char* buffer, std::size_t n) { std::fwrite(buffer, 1, n, f); });
});
}
struct process_impl struct process_impl
{ {
std::string command{}; std::string command{};
...@@ -72,6 +87,15 @@ struct process_impl ...@@ -72,6 +87,15 @@ struct process_impl
result += command; result += command;
return result; return result;
} }
template <class... Ts>
void check_exec(Ts&&... xs) const
{
int ec = migraphx::exec(std::forward<Ts>(xs)...);
if(ec != 0)
MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
std::to_string(ec));
}
}; };
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>()) process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p) ...@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return *this; return *this;
} }
void process::exec() void process::exec() { impl->check_exec(impl->get_command(), redirect_to(std::cout)); }
void process::write(std::function<void(process::writer)> pipe_in)
{ {
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout)); impl->check_exec(impl->get_command(), std::move(pipe_in));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -331,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -331,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params[param_name]; auto param = params[param_name];
// TODO: may want to check correct number of dimensions and/or was within bounds // TODO: may want to check correct number of dimensions and/or was within bounds
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape()) if(not ins->get_shape().any_of_dynamic() and
param.get_shape() != ins->get_shape())
{ {
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name + "} for parameter: " + param_name +
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void promote_literals::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module();
if(m.name() == "main")
return;
for(auto ins : iterator_for(m))
{
if(ins->name() == "@literal")
{
auto new_lit = root_module->add_literal(ins->get_literal());
for(auto out_ins : ins->outputs())
{
out_ins->replace_argument(out_ins, ins, new_lit);
}
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
...@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const ...@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal // Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m)) for(auto i : iterator_for(m))
{ {
if(is_const(i) and i != last) const bool is_const = is_const_ins(i);
if(is_const and i != last)
continue; continue;
std::copy_if( if(i == last and is_const)
i->inputs().begin(), {
i->inputs().end(), const_instrs.insert(i);
std::inserter(const_instrs, const_instrs.begin()), }
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; }); else
{
std::copy_if(i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) {
return is_const_ins(ins) and ins->name() != "@literal";
});
}
} }
// Compute literals in parallel // Compute literals in parallel
......
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment