Commit 3e6a6b3c authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

test case and move pass

parent 2eb12723
......@@ -75,12 +75,8 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
refs,
layout_ins->module_inputs());
// auto layout = m.insert_instruction(
// std::next(output),
// make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}),
// output);
// m.debug_print();
result.insert(layout);
// m.debug_print(layout);
}
return result;
}
......@@ -95,12 +91,20 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
auto precompile_op = ins->get_operator();
auto val = precompile_op.to_value();
if(val["op"].at("name") != "layout")
if(val["op"].at("name").to<std::string>() != "layout")
{
// std::cout << val["op"].at("name").to<std::string>() << std::endl;
continue;
}
m.debug_print(ins);
if(ins->get_shape() != ins->inputs().front()->get_shape())
{
std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << std::endl;
continue;
}
if(contains(output_layouts, ins))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
......
......@@ -135,9 +135,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
compile_miopen{&gctx},
......
......@@ -50,3 +50,27 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu>
return p;
}
};
struct test_conv_add_relu : verify_program<test_conv_add_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 4, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 1}});
auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(bias_literal);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
bias);
auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_bias);
auto relu = mm->add_instruction(migraphx::make_op("relu"), bias_add);
mm->add_instruction(migraphx::make_op("convolution"), relu, weights);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment