Commit 4c031df7 authored by wsttiger's avatar wsttiger
Browse files

Fixed conflicts

parents d32653a5 ed5f9897
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
inline std::string inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string(std::string subject, const std::string& search, const std::string& replace)
...@@ -85,6 +87,7 @@ inline std::string to_string(const T& x) ...@@ -85,6 +87,7 @@ inline std::string to_string(const T& x)
return ss.str(); return ss.str();
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -10,8 +10,10 @@ ...@@ -10,8 +10,10 @@
#include <vector> #include <vector>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/pass.hpp> #include <migraph/pass.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -242,6 +244,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -242,6 +244,7 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/float_equal.hpp> #include <migraph/float_equal.hpp>
#include <migraph/requires.hpp> #include <migraph/requires.hpp>
#include <migraph/config.hpp>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class T> template <class T>
struct tensor_view struct tensor_view
...@@ -167,6 +169,7 @@ tensor_view<T> make_view(shape s, T* data) ...@@ -167,6 +169,7 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data}; return {s, data};
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP #define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#include <chrono> #include <chrono>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class Duration, class F> template <class Duration, class F>
auto time(F f) auto time(F f)
...@@ -14,6 +16,7 @@ auto time(F f) ...@@ -14,6 +16,7 @@ auto time(F f)
return std::chrono::duration_cast<Duration>(finish - start).count(); return std::chrono::duration_cast<Duration>(finish - start).count();
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#include <ostream> #include <ostream>
#include <migraph/functional.hpp> #include <migraph/functional.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct tracer struct tracer
{ {
...@@ -28,6 +30,7 @@ struct tracer ...@@ -28,6 +30,7 @@ struct tracer
std::ostream* os = nullptr; std::ostream* os = nullptr;
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP #define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string> #include <string>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class PrivateMigraphTypeNameProbe> template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name() const std::string& get_type_name()
...@@ -39,6 +41,7 @@ const std::string& get_type_name(const T&) ...@@ -39,6 +41,7 @@ const std::string& get_type_name(const T&)
return migraph::get_type_name<T>(); return migraph::get_type_name<T>();
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -10,8 +10,10 @@ ...@@ -10,8 +10,10 @@
#include <type_traits> #include <type_traits>
#include <migraph/half.hpp> #include <migraph/half.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \ template <class X> \
...@@ -28,6 +30,7 @@ MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) ...@@ -28,6 +30,7 @@ MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
#include <numeric> #include <numeric>
#include <migraph/float_equal.hpp> #include <migraph/float_equal.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -170,5 +172,7 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n ...@@ -170,5 +172,7 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
inline bool verify_args(const std::string& name, inline bool verify_args(const std::string& name,
const argument& cpu_arg, const argument& cpu_arg,
...@@ -82,6 +84,7 @@ inline bool verify_args(const std::string& name, ...@@ -82,6 +84,7 @@ inline bool verify_args(const std::string& name,
return passed; return passed;
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraph/erase.hpp> #include <migraph/erase.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args) instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
...@@ -182,4 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg ...@@ -182,4 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(compute_shapes(args)); return op.compute_shape(compute_shapes(args));
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct unknown struct unknown
{ {
std::string op; std::string op;
...@@ -50,6 +51,9 @@ struct onnx_parser ...@@ -50,6 +51,9 @@ struct onnx_parser
{ {
add_generic_op("MatMul", op::dot{}); add_generic_op("MatMul", op::dot{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_broadcastable_binary_op("Add", op::add{}); add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{}); add_broadcastable_binary_op("Div", op::div{});
...@@ -74,6 +78,7 @@ struct onnx_parser ...@@ -74,6 +78,7 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat); add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
} }
template <class F> template <class F>
...@@ -426,6 +431,18 @@ struct onnx_parser ...@@ -426,6 +431,18 @@ struct onnx_parser
return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
} }
instruction_ref
parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> perm{};
if(contains(attributes, "perm"))
{
auto&& perm_vals = attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return prog.add_instruction(migraph::op::transpose{perm}, args.front());
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
...@@ -586,7 +603,7 @@ struct onnx_parser ...@@ -586,7 +603,7 @@ struct onnx_parser
case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()}; case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error(""); case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()}; case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
...@@ -614,7 +631,8 @@ struct onnx_parser ...@@ -614,7 +631,8 @@ struct onnx_parser
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error(""); case onnx::TensorProto::FLOAT16:
return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return literal{ return literal{
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()}; {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
...@@ -645,8 +663,7 @@ struct onnx_parser ...@@ -645,8 +663,7 @@ struct onnx_parser
break; // throw std::runtime_error("Unsupported type STRING"); break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL"); break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break; case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break; case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break; case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
...@@ -693,4 +710,5 @@ program parse_onnx(const std::string& name) ...@@ -693,4 +710,5 @@ program parse_onnx(const std::string& name)
return std::move(parser.prog); return std::move(parser.prog);
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/pass_config.hpp> #include <migraph/pass_config.hpp>
#include <migraph/config.hpp>
#include <set> #include <set>
#include <list> #include <list>
...@@ -13,6 +14,7 @@ ...@@ -13,6 +14,7 @@
#include <queue> #include <queue>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT //#define MIGRAPH_DEBUG_OPT
...@@ -21,6 +23,8 @@ namespace migraph { ...@@ -21,6 +23,8 @@ namespace migraph {
#else #else
#define MIGRAPH_DEBUG(s) #define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT #endif // MIGRAPH_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(program& p) const
{ {
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op); memory_coloring_impl opt(&p, allocation_op, verify);
opt.run(); opt.run();
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
MIGRAPH_DEBUG(dump("---Before memory coloring---")); MIGRAPH_DEBUG(dump("---Before memory coloring---"));
MIGRAPH_DEBUG(dump_program()); MIGRAPH_DEBUG(dump_program());
register_operand_alias();
build(); build();
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
...@@ -19,7 +19,8 @@ void memory_coloring_impl::run() ...@@ -19,7 +19,8 @@ void memory_coloring_impl::run()
alloc_queue.pop(); alloc_queue.pop();
} }
rewrite(); rewrite();
MIGRAPH_DEBUG(verify()); if(enable_verify)
verify();
} }
} }
...@@ -129,11 +130,8 @@ void memory_coloring_impl::build() ...@@ -129,11 +130,8 @@ void memory_coloring_impl::build()
{ {
is_dead = true; is_dead = true;
} }
int tie_ndx = get_input_tie_ndx(iter);
int cnt = -1;
for(auto&& arg : iter->inputs()) for(auto&& arg : iter->inputs())
{ {
cnt++;
if(is_param(arg) || is_outline(arg)) if(is_param(arg) || is_outline(arg))
{ {
if(is_output_param(arg)) if(is_output_param(arg))
...@@ -144,15 +142,8 @@ void memory_coloring_impl::build() ...@@ -144,15 +142,8 @@ void memory_coloring_impl::build()
} }
continue; continue;
} }
const instruction* p_arg = &(*arg); const instruction* p_arg = &(*instruction::get_output_alias(arg));
if(cnt == tie_ndx && (def_interval != nullptr)) if(instr2_live.find(p_arg) == instr2_live.end())
{
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
def_interval->add_use(cur_points);
instr2_live[p_arg] = def_interval;
}
else if(instr2_live.find(p_arg) == instr2_live.end())
{ {
// First time see a use, create a live interval. // First time see a use, create a live interval.
int id = num_of_lives++; int id = num_of_lives++;
...@@ -182,23 +173,6 @@ void memory_coloring_impl::build() ...@@ -182,23 +173,6 @@ void memory_coloring_impl::build()
} while(iter != begin); } while(iter != begin);
} }
void memory_coloring_impl::register_operand_alias()
{
operand_alias["hip::allocate"] = -1;
operand_alias["hip::load_literal"] = -1;
operand_alias["@outline"] = -1;
operand_alias["check_context"] = -1;
operand_alias["@literal"] = -1;
operand_alias["@param"] = -1;
operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 0;
operand_alias["identity"] = 0;
operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
operand_alias["scalar"] = 0;
}
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
{ {
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
...@@ -248,37 +222,6 @@ void memory_coloring_impl::rewrite() ...@@ -248,37 +222,6 @@ void memory_coloring_impl::rewrite()
MIGRAPH_DEBUG(dump_program()); MIGRAPH_DEBUG(dump_program());
} }
#ifdef MIGRAPH_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
}
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
std::set<int>& table = conflict_table[i];
for(auto& iter : table)
{
std::cout << (iter) << ",";
}
}
std::cout << std::endl;
}
}
void memory_coloring_impl::verify() void memory_coloring_impl::verify()
{ {
if(num_of_lives > 0) if(num_of_lives > 0)
...@@ -290,7 +233,9 @@ void memory_coloring_impl::verify() ...@@ -290,7 +233,9 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset) if(segment.begin == invalid_offset)
{ {
assert(interval.is_live_on_entry); // TODO: This check breaks on the tests
// if(!interval.is_live_on_entry)
// MIGRAPH_THROW("interval is not live on entry");
continue; continue;
} }
...@@ -308,10 +253,41 @@ void memory_coloring_impl::verify() ...@@ -308,10 +253,41 @@ void memory_coloring_impl::verify()
if(range->offset == invalid_offset) if(range->offset == invalid_offset)
continue; continue;
if(!is_disjoin(*range, segment)) if(!is_disjoin(*range, segment))
assert(false); MIGRAPH_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPH_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
} }
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
std::set<int>& table = conflict_table[i];
for(auto& iter : table)
{
std::cout << (iter) << ",";
} }
} }
std::cout << std::endl;
} }
} }
...@@ -357,4 +333,6 @@ void live_interval::dump() ...@@ -357,4 +333,6 @@ void live_interval::dump()
} }
#endif #endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
static const int invalid_offset = -1; static const int invalid_offset = -1;
...@@ -50,8 +52,8 @@ using interval_ptr = live_interval*; ...@@ -50,8 +52,8 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p, std::string alloc_op) memory_coloring_impl(program* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)) : p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{ {
instr2_live.clear(); instr2_live.clear();
live_ranges.clear(); live_ranges.clear();
...@@ -59,7 +61,6 @@ struct memory_coloring_impl ...@@ -59,7 +61,6 @@ struct memory_coloring_impl
num_of_lives = 0; num_of_lives = 0;
max_value_number = -1; max_value_number = -1;
required_bytes = 0; required_bytes = 0;
operand_alias.clear();
earliest_end_point = -1; earliest_end_point = -1;
latest_end_point = -1; latest_end_point = -1;
unify_literals = false; unify_literals = false;
...@@ -75,7 +76,6 @@ struct memory_coloring_impl ...@@ -75,7 +76,6 @@ struct memory_coloring_impl
} }
void build(); void build();
void run(); void run();
void register_operand_alias();
void rewrite(); void rewrite();
private: private:
...@@ -92,31 +92,6 @@ struct memory_coloring_impl ...@@ -92,31 +92,6 @@ struct memory_coloring_impl
return ins->name() == "check_context"; return ins->name() == "check_context";
} }
// get operand alias info. This is a temporary workaround.
int get_input_tie_ndx(const instruction_ref ins)
{
std::string name = ins->name();
if(operand_alias.find(name) != operand_alias.end())
return operand_alias[name];
if(is_allocate(ins))
{
// This happens to custom allocators.
operand_alias[name] = -1;
return -1;
}
int cnt = -1;
int last_allocate = -1;
for(auto&& arg : ins->inputs())
{
cnt++;
if(is_allocate(arg) || is_output_param(arg))
last_allocate = cnt;
}
assert(last_allocate != -1);
operand_alias[name] = last_allocate;
return last_allocate;
}
#ifdef MIGRAPH_DEBUG_OPT
static bool is_disjoin(live_range& range1, live_range& range2) static bool is_disjoin(live_range& range1, live_range& range2)
{ {
if((range1.size == 0) || (range2.size == 0)) if((range1.size == 0) || (range2.size == 0))
...@@ -125,10 +100,11 @@ struct memory_coloring_impl ...@@ -125,10 +100,11 @@ struct memory_coloring_impl
long long end2 = range2.offset + range2.size - 1; long long end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) || (end2 < range1.offset)); return ((end1 < range2.offset) || (end2 < range1.offset));
} }
void verify();
#ifdef MIGRAPH_DEBUG_OPT
void dump(const std::string&); void dump(const std::string&);
void dump_program(); void dump_program();
void dump_intervals(); void dump_intervals();
void verify();
#endif #endif
struct ordering struct ordering
{ {
...@@ -164,7 +140,6 @@ struct memory_coloring_impl ...@@ -164,7 +140,6 @@ struct memory_coloring_impl
std::unordered_map<int, std::set<int>> conflict_table; std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring. // Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue; std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue;
std::unordered_map<std::string, int> operand_alias;
int num_of_lives; int num_of_lives;
int max_value_number; int max_value_number;
...@@ -176,6 +151,9 @@ struct memory_coloring_impl ...@@ -176,6 +151,9 @@ struct memory_coloring_impl
// Whether to unify literals into coloring. // Whether to unify literals into coloring.
bool unify_literals; bool unify_literals;
std::string allocation_op{}; std::string allocation_op{};
bool enable_verify;
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <utility> #include <utility>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE) MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL) MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)
...@@ -281,7 +282,7 @@ void program::compile(const target& t, tracer trace) ...@@ -281,7 +282,7 @@ void program::compile(const target& t, tracer trace)
{ {
assert(this->validate() == impl->instructions.end()); assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
if(not trace.enabled() or enabled(MIGRAPH_TRACE_COMPILE{})) if(enabled(MIGRAPH_TRACE_COMPILE{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
trace(*this); trace(*this);
trace(); trace();
...@@ -498,4 +499,6 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -498,4 +499,6 @@ std::ostream& operator<<(std::ostream& os, const program& p)
print_program(os, p, [](auto&&...) {}); print_program(os, p, [](auto&&...) {});
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <iostream> #include <iostream>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct shape_impl struct shape_impl
{ {
...@@ -190,4 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x) ...@@ -190,4 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
...@@ -60,4 +61,5 @@ struct find_add_lit_broadcast ...@@ -60,4 +61,5 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); } void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <unordered_set> #include <unordered_set>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
bool is_reshaper(const std::string& name) bool is_reshaper(const std::string& name)
{ {
...@@ -59,4 +60,5 @@ void simplify_reshapes::apply(program& p) const ...@@ -59,4 +60,5 @@ void simplify_reshapes::apply(program& p) const
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace cpu { namespace cpu {
template <class T> template <class T>
...@@ -93,5 +94,5 @@ void migemm( ...@@ -93,5 +94,5 @@ void migemm(
} }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment