Commit 58b00ddf authored by mei-ye's avatar mei-ye
Browse files

change unit test and pass in allocate

parent 39eedcd9
...@@ -9,6 +9,7 @@ struct program; ...@@ -9,6 +9,7 @@ struct program;
struct memory_coloring struct memory_coloring
{ {
std::string allocation_op{};
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
...@@ -7,7 +7,7 @@ void memory_coloring::apply(program& p) const ...@@ -7,7 +7,7 @@ void memory_coloring::apply(program& p) const
{ {
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p); memory_coloring_impl opt(&p, allocation_op);
opt.run(); opt.run();
} }
} }
......
...@@ -222,8 +222,9 @@ void memory_coloring_impl::rewrite() ...@@ -222,8 +222,9 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins)) if(is_allocate(ins))
{ {
p_program->replace_instruction( if(!ins->arguments.empty())
ins, load{ins->arguments.at(0)->result, offset}, scratch_param); p_program->replace_instruction(
ins, load{ins->arguments.at(0)->result, offset}, scratch_param);
} }
else if(is_literal(ins)) else if(is_literal(ins))
{ {
......
...@@ -50,7 +50,7 @@ using interval_ptr = live_interval*; ...@@ -50,7 +50,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p) : p_program(p) memory_coloring_impl(program* p, std::string alloc_op) : p_program(p), allocation_op(std::move(alloc_op))
{ {
instr2_live.clear(); instr2_live.clear();
live_ranges.clear(); live_ranges.clear();
...@@ -81,7 +81,7 @@ struct memory_coloring_impl ...@@ -81,7 +81,7 @@ struct memory_coloring_impl
{ {
return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output"; return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
} }
static bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; } bool is_allocate(const instruction_ref ins) { return ins->op.name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; } static bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; }
static bool is_check_context(const instruction_ref ins) static bool is_check_context(const instruction_ref ins)
...@@ -95,7 +95,12 @@ struct memory_coloring_impl ...@@ -95,7 +95,12 @@ struct memory_coloring_impl
std::string name = ins->op.name(); std::string name = ins->op.name();
if(operand_alias.find(name) != operand_alias.end()) if(operand_alias.find(name) != operand_alias.end())
return operand_alias[name]; return operand_alias[name];
if(is_allocate(ins))
{
// This happens to custom allocators.
operand_alias[name] = -1;
return -1;
}
int cnt = -1; int cnt = -1;
int last_allocate = -1; int last_allocate = -1;
for(auto&& arg : ins->arguments) for(auto&& arg : ins->arguments)
...@@ -104,7 +109,7 @@ struct memory_coloring_impl ...@@ -104,7 +109,7 @@ struct memory_coloring_impl
if(is_allocate(arg) || is_output_param(arg)) if(is_allocate(arg) || is_output_param(arg))
last_allocate = cnt; last_allocate = cnt;
} }
assert((last_allocate != -1)); assert(last_allocate != -1);
operand_alias[name] = last_allocate; operand_alias[name] = last_allocate;
return last_allocate; return last_allocate;
} }
...@@ -163,6 +168,7 @@ struct memory_coloring_impl ...@@ -163,6 +168,7 @@ struct memory_coloring_impl
long long required_bytes; long long required_bytes;
// The earliest program point where an live interval ends. // The earliest program point where an live interval ends.
int earliest_end_point; int earliest_end_point;
std::string allocation_op{};
}; };
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -30,7 +30,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -30,7 +30,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
memory_coloring{}, memory_coloring{"hip::allocate"},
lowering_memory_coloring{&ctx}, lowering_memory_coloring{&ctx},
fuse_ops{}, fuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -8,17 +8,35 @@ struct memory_coloring_target ...@@ -8,17 +8,35 @@ struct memory_coloring_target
std::string name() const { return "memory_coloring"; } std::string name() const { return "memory_coloring"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraph::pass> get_passes(migraph::context&) const
{ {
return {migraph::memory_coloring{}}; return {migraph::memory_coloring{"allocate"}};
} }
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(0);
return s;
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
int main() int main()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}});
p.add_instruction(migraph::transpose{{1, 0}}, l); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {40}}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 16); EXPECT(p.get_parameter_shape("scratch").bytes() == 192);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment