Commit 58b00ddf authored by mei-ye's avatar mei-ye
Browse files

change unit test and pass in allocate

parent 39eedcd9
......@@ -9,6 +9,7 @@ struct program;
struct memory_coloring
{
std::string allocation_op{};
std::string name() const { return "memory coloring"; }
void apply(program& p) const;
};
......
......@@ -7,7 +7,7 @@ void memory_coloring::apply(program& p) const
{
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p);
memory_coloring_impl opt(&p, allocation_op);
opt.run();
}
}
......
......@@ -222,6 +222,7 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins))
{
if(!ins->arguments.empty())
p_program->replace_instruction(
ins, load{ins->arguments.at(0)->result, offset}, scratch_param);
}
......
......@@ -50,7 +50,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(program* p) : p_program(p)
memory_coloring_impl(program* p, std::string alloc_op) : p_program(p), allocation_op(std::move(alloc_op))
{
instr2_live.clear();
live_ranges.clear();
......@@ -81,7 +81,7 @@ struct memory_coloring_impl
{
return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
}
static bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool is_allocate(const instruction_ref ins) { return ins->op.name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
......@@ -95,7 +95,12 @@ struct memory_coloring_impl
std::string name = ins->op.name();
if(operand_alias.find(name) != operand_alias.end())
return operand_alias[name];
if(is_allocate(ins))
{
// This happens to custom allocators.
operand_alias[name] = -1;
return -1;
}
int cnt = -1;
int last_allocate = -1;
for(auto&& arg : ins->arguments)
......@@ -104,7 +109,7 @@ struct memory_coloring_impl
if(is_allocate(arg) || is_output_param(arg))
last_allocate = cnt;
}
assert((last_allocate != -1));
assert(last_allocate != -1);
operand_alias[name] = last_allocate;
return last_allocate;
}
......@@ -163,6 +168,7 @@ struct memory_coloring_impl
long long required_bytes;
// The earliest program point where an live interval ends.
int earliest_end_point;
std::string allocation_op{};
};
} // namespace migraph
#endif
......@@ -30,7 +30,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
memory_coloring{},
memory_coloring{"hip::allocate"},
lowering_memory_coloring{&ctx},
fuse_ops{},
dead_code_elimination{},
......
......@@ -8,17 +8,35 @@ struct memory_coloring_target
std::string name() const { return "memory_coloring"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::memory_coloring{}};
return {migraph::memory_coloring{"allocate"}};
}
migraph::context get_context() const { return {}; }
};
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(0);
return s;
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
int main()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
p.add_instruction(migraph::transpose{{1, 0}}, l);
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {40}}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 16);
EXPECT(p.get_parameter_shape("scratch").bytes() == 192);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment