Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -14,7 +14,7 @@ bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::s ...@@ -14,7 +14,7 @@ bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::s
not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{}); not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{});
} }
std::vector<stream_race> analyze_streams(const program& p, const stream_model& m) std::vector<stream_race> analyze_streams(const module& p, const stream_model& m)
{ {
using vector_clock = std::vector<std::size_t>; using vector_clock = std::vector<std::size_t>;
std::vector<stream_race> races; std::vector<stream_race> races;
......
...@@ -118,8 +118,8 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) ...@@ -118,8 +118,8 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct quantize_int8_options struct quantize_int8_options
{ {
std::vector<program::parameter_map> calibration = {}; std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {}; std::vector<std::string> op_names = {};
}; };
void add_op_name(quantize_int8_options& options, const char* name) void add_op_name(quantize_int8_options& options, const char* name)
...@@ -127,7 +127,7 @@ void add_op_name(quantize_int8_options& options, const char* name) ...@@ -127,7 +127,7 @@ void add_op_name(quantize_int8_options& options, const char* name)
options.op_names.push_back(name); options.op_names.push_back(name);
} }
void add_calibration_data(quantize_int8_options& options, program::parameter_map& data) void add_calibration_data(quantize_int8_options& options, parameter_map& data)
{ {
options.calibration.push_back(data); options.calibration.push_back(data);
} }
...@@ -160,10 +160,7 @@ bool equal(const T& x, const T& y) ...@@ -160,10 +160,7 @@ bool equal(const T& x, const T& y)
return x == y; return x == y;
} }
std::vector<argument> run(program& p, const program::parameter_map& params) std::vector<argument> run(program& p, const parameter_map& params) { return p.eval(params); }
{
return p.eval(params);
}
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); } std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(program& p) const void auto_contiguous::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -29,7 +29,7 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last) ...@@ -29,7 +29,7 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
return -n; return -n;
} }
void dead_code_elimination::apply(program& p) const void dead_code_elimination::apply(module& p) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
......
...@@ -18,7 +18,7 @@ struct find_dot_add ...@@ -18,7 +18,7 @@ struct find_dot_add
{ {
auto matcher() const { return match::name("dot")(match::nargs(3)); } auto matcher() const { return match::name("dot")(match::nargs(3)); }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto dot = any_cast<op::dot>(ins->get_operator()); auto dot = any_cast<op::dot>(ins->get_operator());
...@@ -42,7 +42,7 @@ struct find_dot_add ...@@ -42,7 +42,7 @@ struct find_dot_add
} // namespace } // namespace
void decompose::apply(program& p) const { match::find_matches(p, find_dot_add{}); } void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -10,169 +10,170 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,169 +10,170 @@ inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto m0 = auto m0 =
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx0 = p.add_literal( auto mx0 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
auto mx1 = p.add_literal( auto mx1 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
auto mx2 = p.add_literal( auto mx2 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
auto mx3 = p.add_literal( auto mx3 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = p.add_literal( auto mx4 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = p.add_literal( auto mx5 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = p.add_literal( auto mx6 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
auto mx7 = p.add_literal(migraphx::generate_literal( auto mx7 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7)); migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
auto mx8 = p.add_literal( auto mx8 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
auto mx9 = p.add_literal(migraphx::generate_literal( auto mx9 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = p.add_literal( auto mx10 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
auto mx11 = p.add_literal(migraphx::generate_literal( auto mx11 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = p.add_literal( auto mx12 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
auto mx13 = p.add_literal(migraphx::generate_literal( auto mx13 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = p.add_literal( auto mx14 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
auto mx15 = p.add_literal(migraphx::generate_literal( auto mx15 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15)); migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
migraphx::op::convolution convolution16; migraphx::op::convolution convolution16;
convolution16.padding = {2, 2}; convolution16.padding = {2, 2};
convolution16.stride = {4, 4}; convolution16.stride = {4, 4};
convolution16.dilation = {1, 1}; convolution16.dilation = {1, 1};
convolution16.group = 1; convolution16.group = 1;
auto mx16 = p.add_instruction(convolution16, m0, mx15); auto mx16 = mm->add_instruction(convolution16, m0, mx15);
migraphx::op::broadcast broadcast17; migraphx::op::broadcast broadcast17;
broadcast17.axis = 1; broadcast17.axis = 1;
broadcast17.broadcast_lens = {batch, 64, 55, 55}; broadcast17.broadcast_lens = {batch, 64, 55, 55};
auto mx17 = p.add_instruction(broadcast17, mx14); auto mx17 = mm->add_instruction(broadcast17, mx14);
migraphx::op::add add18; migraphx::op::add add18;
auto mx18 = p.add_instruction(add18, mx16, mx17); auto mx18 = mm->add_instruction(add18, mx16, mx17);
migraphx::op::relu relu19; migraphx::op::relu relu19;
auto mx19 = p.add_instruction(relu19, mx18); auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20; migraphx::op::pooling pooling20;
pooling20.mode = "max"; pooling20.mode = "max";
pooling20.padding = {0, 0}; pooling20.padding = {0, 0};
pooling20.stride = {2, 2}; pooling20.stride = {2, 2};
pooling20.lengths = {3, 3}; pooling20.lengths = {3, 3};
auto mx20 = p.add_instruction(pooling20, mx19); auto mx20 = mm->add_instruction(pooling20, mx19);
migraphx::op::convolution convolution21; migraphx::op::convolution convolution21;
convolution21.padding = {2, 2}; convolution21.padding = {2, 2};
convolution21.stride = {1, 1}; convolution21.stride = {1, 1};
convolution21.dilation = {1, 1}; convolution21.dilation = {1, 1};
convolution21.group = 1; convolution21.group = 1;
auto mx21 = p.add_instruction(convolution21, mx20, mx13); auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
migraphx::op::broadcast broadcast22; migraphx::op::broadcast broadcast22;
broadcast22.axis = 1; broadcast22.axis = 1;
broadcast22.broadcast_lens = {batch, 192, 27, 27}; broadcast22.broadcast_lens = {batch, 192, 27, 27};
auto mx22 = p.add_instruction(broadcast22, mx12); auto mx22 = mm->add_instruction(broadcast22, mx12);
migraphx::op::add add23; migraphx::op::add add23;
auto mx23 = p.add_instruction(add23, mx21, mx22); auto mx23 = mm->add_instruction(add23, mx21, mx22);
migraphx::op::relu relu24; migraphx::op::relu relu24;
auto mx24 = p.add_instruction(relu24, mx23); auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25; migraphx::op::pooling pooling25;
pooling25.mode = "max"; pooling25.mode = "max";
pooling25.padding = {0, 0}; pooling25.padding = {0, 0};
pooling25.stride = {2, 2}; pooling25.stride = {2, 2};
pooling25.lengths = {3, 3}; pooling25.lengths = {3, 3};
auto mx25 = p.add_instruction(pooling25, mx24); auto mx25 = mm->add_instruction(pooling25, mx24);
migraphx::op::convolution convolution26; migraphx::op::convolution convolution26;
convolution26.padding = {1, 1}; convolution26.padding = {1, 1};
convolution26.stride = {1, 1}; convolution26.stride = {1, 1};
convolution26.dilation = {1, 1}; convolution26.dilation = {1, 1};
convolution26.group = 1; convolution26.group = 1;
auto mx26 = p.add_instruction(convolution26, mx25, mx11); auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
migraphx::op::broadcast broadcast27; migraphx::op::broadcast broadcast27;
broadcast27.axis = 1; broadcast27.axis = 1;
broadcast27.broadcast_lens = {batch, 384, 13, 13}; broadcast27.broadcast_lens = {batch, 384, 13, 13};
auto mx27 = p.add_instruction(broadcast27, mx10); auto mx27 = mm->add_instruction(broadcast27, mx10);
migraphx::op::add add28; migraphx::op::add add28;
auto mx28 = p.add_instruction(add28, mx26, mx27); auto mx28 = mm->add_instruction(add28, mx26, mx27);
migraphx::op::relu relu29; migraphx::op::relu relu29;
auto mx29 = p.add_instruction(relu29, mx28); auto mx29 = mm->add_instruction(relu29, mx28);
migraphx::op::convolution convolution30; migraphx::op::convolution convolution30;
convolution30.padding = {1, 1}; convolution30.padding = {1, 1};
convolution30.stride = {1, 1}; convolution30.stride = {1, 1};
convolution30.dilation = {1, 1}; convolution30.dilation = {1, 1};
convolution30.group = 1; convolution30.group = 1;
auto mx30 = p.add_instruction(convolution30, mx29, mx9); auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
migraphx::op::broadcast broadcast31; migraphx::op::broadcast broadcast31;
broadcast31.axis = 1; broadcast31.axis = 1;
broadcast31.broadcast_lens = {batch, 256, 13, 13}; broadcast31.broadcast_lens = {batch, 256, 13, 13};
auto mx31 = p.add_instruction(broadcast31, mx8); auto mx31 = mm->add_instruction(broadcast31, mx8);
migraphx::op::add add32; migraphx::op::add add32;
auto mx32 = p.add_instruction(add32, mx30, mx31); auto mx32 = mm->add_instruction(add32, mx30, mx31);
migraphx::op::relu relu33; migraphx::op::relu relu33;
auto mx33 = p.add_instruction(relu33, mx32); auto mx33 = mm->add_instruction(relu33, mx32);
migraphx::op::convolution convolution34; migraphx::op::convolution convolution34;
convolution34.padding = {1, 1}; convolution34.padding = {1, 1};
convolution34.stride = {1, 1}; convolution34.stride = {1, 1};
convolution34.dilation = {1, 1}; convolution34.dilation = {1, 1};
convolution34.group = 1; convolution34.group = 1;
auto mx34 = p.add_instruction(convolution34, mx33, mx7); auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
migraphx::op::broadcast broadcast35; migraphx::op::broadcast broadcast35;
broadcast35.axis = 1; broadcast35.axis = 1;
broadcast35.broadcast_lens = {batch, 256, 13, 13}; broadcast35.broadcast_lens = {batch, 256, 13, 13};
auto mx35 = p.add_instruction(broadcast35, mx6); auto mx35 = mm->add_instruction(broadcast35, mx6);
migraphx::op::add add36; migraphx::op::add add36;
auto mx36 = p.add_instruction(add36, mx34, mx35); auto mx36 = mm->add_instruction(add36, mx34, mx35);
migraphx::op::relu relu37; migraphx::op::relu relu37;
auto mx37 = p.add_instruction(relu37, mx36); auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38; migraphx::op::pooling pooling38;
pooling38.mode = "max"; pooling38.mode = "max";
pooling38.padding = {0, 0}; pooling38.padding = {0, 0};
pooling38.stride = {2, 2}; pooling38.stride = {2, 2};
pooling38.lengths = {3, 3}; pooling38.lengths = {3, 3};
auto mx38 = p.add_instruction(pooling38, mx37); auto mx38 = mm->add_instruction(pooling38, mx37);
migraphx::op::flatten flatten39; migraphx::op::flatten flatten39;
flatten39.axis = 1; flatten39.axis = 1;
auto mx39 = p.add_instruction(flatten39, mx38); auto mx39 = mm->add_instruction(flatten39, mx38);
migraphx::op::identity identity40; migraphx::op::identity identity40;
auto mx40 = p.add_instruction(identity40, mx39); auto mx40 = mm->add_instruction(identity40, mx39);
migraphx::op::transpose transpose41; migraphx::op::transpose transpose41;
transpose41.dims = {1, 0}; transpose41.dims = {1, 0};
auto mx41 = p.add_instruction(transpose41, mx5); auto mx41 = mm->add_instruction(transpose41, mx5);
migraphx::op::multibroadcast multibroadcast42; migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096}; multibroadcast42.output_lens = {batch, 4096};
auto mx42 = p.add_instruction(multibroadcast42, mx4); auto mx42 = mm->add_instruction(multibroadcast42, mx4);
migraphx::op::dot dot43; migraphx::op::dot dot43;
dot43.alpha = 1; dot43.alpha = 1;
dot43.beta = 1; dot43.beta = 1;
auto mx43 = p.add_instruction(dot43, mx40, mx41, mx42); auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42);
migraphx::op::relu relu44; migraphx::op::relu relu44;
auto mx44 = p.add_instruction(relu44, mx43); auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45; migraphx::op::identity identity45;
auto mx45 = p.add_instruction(identity45, mx44); auto mx45 = mm->add_instruction(identity45, mx44);
migraphx::op::transpose transpose46; migraphx::op::transpose transpose46;
transpose46.dims = {1, 0}; transpose46.dims = {1, 0};
auto mx46 = p.add_instruction(transpose46, mx3); auto mx46 = mm->add_instruction(transpose46, mx3);
migraphx::op::multibroadcast multibroadcast47; migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096}; multibroadcast47.output_lens = {batch, 4096};
auto mx47 = p.add_instruction(multibroadcast47, mx2); auto mx47 = mm->add_instruction(multibroadcast47, mx2);
migraphx::op::dot dot48; migraphx::op::dot dot48;
dot48.alpha = 1; dot48.alpha = 1;
dot48.beta = 1; dot48.beta = 1;
auto mx48 = p.add_instruction(dot48, mx45, mx46, mx47); auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47);
migraphx::op::relu relu49; migraphx::op::relu relu49;
auto mx49 = p.add_instruction(relu49, mx48); auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50; migraphx::op::transpose transpose50;
transpose50.dims = {1, 0}; transpose50.dims = {1, 0};
auto mx50 = p.add_instruction(transpose50, mx1); auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51; migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000}; multibroadcast51.output_lens = {batch, 1000};
auto mx51 = p.add_instruction(multibroadcast51, mx0); auto mx51 = mm->add_instruction(multibroadcast51, mx0);
migraphx::op::dot dot52; migraphx::op::dot dot52;
dot52.alpha = 1; dot52.alpha = 1;
dot52.beta = 1; dot52.beta = 1;
p.add_instruction(dot52, mx49, mx50, mx51); mm->add_instruction(dot52, mx49, mx50, mx51);
return p; return p;
} }
......
This diff is collapsed.
...@@ -135,10 +135,12 @@ struct loader ...@@ -135,10 +135,12 @@ struct loader
if(trim > 0) if(trim > 0)
{ {
auto last = std::prev(p.end(), trim); auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end()); auto* mm = p.get_main_module();
mm->remove_instructions(last, p.end());
} }
if(optimize) if(optimize)
migraphx::run_passes(p, {
migraphx::run_passes(*p.get_main_module(),
{ {
migraphx::rewrite_batchnorm{}, migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{}, migraphx::eliminate_identity{},
...@@ -152,6 +154,7 @@ struct loader ...@@ -152,6 +154,7 @@ struct loader
migraphx::eliminate_pad{}, migraphx::eliminate_pad{},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
}); });
}
return p; return p;
} }
...@@ -204,7 +207,7 @@ struct program_params ...@@ -204,7 +207,7 @@ struct program_params
auto generate(const program& p, const target& t, bool offload) auto generate(const program& p, const target& t, bool offload)
{ {
program::parameter_map m; parameter_map m;
for(auto&& s : fill0) for(auto&& s : fill0)
m[s] = fill_argument(p.get_parameter_shape(s), 0); m[s] = fill_argument(p.get_parameter_shape(s), 0);
for(auto&& s : fill1) for(auto&& s : fill1)
......
...@@ -16,8 +16,7 @@ auto get_hash(const T& x) ...@@ -16,8 +16,7 @@ auto get_hash(const T& x)
return std::hash<T>{}(x); return std::hash<T>{}(x);
} }
program::parameter_map parameter_map fill_param_map(parameter_map& m, const program& p, const target& t, bool offload)
fill_param_map(program::parameter_map& m, const program& p, const target& t, bool offload)
{ {
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
...@@ -30,7 +29,7 @@ fill_param_map(program::parameter_map& m, const program& p, const target& t, boo ...@@ -30,7 +29,7 @@ fill_param_map(program::parameter_map& m, const program& p, const target& t, boo
return m; return m;
} }
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu) parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu)
{ {
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
...@@ -47,9 +46,9 @@ program::parameter_map fill_param_map(program::parameter_map& m, const program& ...@@ -47,9 +46,9 @@ program::parameter_map fill_param_map(program::parameter_map& m, const program&
return m; return m;
} }
program::parameter_map create_param_map(const program& p, const target& t, bool offload) parameter_map create_param_map(const program& p, const target& t, bool offload)
{ {
program::parameter_map m; parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
auto arg = generate_argument(x.second, get_hash(x.first)); auto arg = generate_argument(x.second, get_hash(x.first));
...@@ -61,9 +60,9 @@ program::parameter_map create_param_map(const program& p, const target& t, bool ...@@ -61,9 +60,9 @@ program::parameter_map create_param_map(const program& p, const target& t, bool
return m; return m;
} }
program::parameter_map create_param_map(const program& p, bool gpu) parameter_map create_param_map(const program& p, bool gpu)
{ {
program::parameter_map m; parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
#ifdef HAVE_GPU #ifdef HAVE_GPU
......
...@@ -7,12 +7,12 @@ namespace migraphx { ...@@ -7,12 +7,12 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map parameter_map
fill_param_map(program::parameter_map& m, const program& p, const target& t, bool offload = false); fill_param_map(parameter_map& m, const program& p, const target& t, bool offload = false);
program::parameter_map create_param_map(const program& p, const target& t, bool offload = false); parameter_map create_param_map(const program& p, const target& t, bool offload = false);
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
program::parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
void compile_program(program& p, bool gpu = true); void compile_program(program& p, bool gpu = true);
......
This diff is collapsed.
...@@ -11,7 +11,7 @@ namespace migraphx { ...@@ -11,7 +11,7 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::vector<argument> run_ref(program p, const program::parameter_map& inputs) std::vector<argument> run_ref(program p, const parameter_map& inputs)
{ {
p.compile(ref::target{}); p.compile(ref::target{});
auto out = p.eval(inputs); auto out = p.eval(inputs);
...@@ -19,14 +19,12 @@ std::vector<argument> run_ref(program p, const program::parameter_map& inputs) ...@@ -19,14 +19,12 @@ std::vector<argument> run_ref(program p, const program::parameter_map& inputs)
return out; return out;
} }
std::vector<argument> run_target(program p, std::vector<argument>
const target& t, run_target(program p, const target& t, const compile_options& options, const parameter_map& inputs)
const compile_options& options,
const program::parameter_map& inputs)
{ {
p.compile(t, options); p.compile(t, options);
program::parameter_map m; parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first); auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first);
...@@ -45,7 +43,7 @@ void verify_program(const std::string& name, ...@@ -45,7 +43,7 @@ void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
compile_options options, compile_options options,
const program::parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
auto x = run_ref(p, inputs); auto x = run_ref(p, inputs);
...@@ -65,7 +63,8 @@ void verify_instructions(const program& prog, ...@@ -65,7 +63,8 @@ void verify_instructions(const program& prog,
compile_options options, compile_options options,
double tolerance) double tolerance)
{ {
for(auto&& ins : prog) const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog))
{ {
if(ins.name().front() == '@') if(ins.name().front() == '@')
continue; continue;
...@@ -78,15 +77,17 @@ void verify_instructions(const program& prog, ...@@ -78,15 +77,17 @@ void verify_instructions(const program& prog,
if(ins.name() == "undefined") if(ins.name() == "undefined")
continue; continue;
program p; program p;
auto* mm_p = p.get_main_module();
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
for(auto&& arg : ins.inputs()) for(auto&& arg : ins.inputs())
{ {
if(arg->name() == "@literal") if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal())); inputs.push_back(mm_p->add_literal(arg->get_literal()));
else else
inputs.push_back(p.add_parameter(std::to_string(inputs.size()), arg->get_shape())); inputs.push_back(
mm_p->add_parameter(std::to_string(inputs.size()), arg->get_shape()));
} }
p.add_instruction(ins.get_operator(), inputs); mm_p->add_instruction(ins.get_operator(), inputs);
try try
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
...@@ -105,11 +106,12 @@ void verify_reduced(program p, ...@@ -105,11 +106,12 @@ void verify_reduced(program p,
int n, int n,
const target& t, const target& t,
compile_options options, compile_options options,
const program::parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
auto* mm = p.get_main_module();
auto last = std::prev(p.end(), n + 1); auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end()); mm->remove_instructions(last, p.end());
std::cout << "Verify: " << std::endl; std::cout << "Verify: " << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, inputs, tolerance); verify_program(std::to_string(n), p, t, options, inputs, tolerance);
...@@ -118,7 +120,7 @@ void verify_reduced(program p, ...@@ -118,7 +120,7 @@ void verify_reduced(program p,
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options, compile_options options,
const program::parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
......
...@@ -10,18 +10,18 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,18 +10,18 @@ inline namespace MIGRAPHX_INLINE_NS {
void verify_program(const std::string& name, void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
const program::parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); double tolerance = 100);
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
double tolerance = 80); double tolerance = 80);
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
const program::parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); double tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(program& p) const void eliminate_allocation::apply(module& p) const
{ {
assert(alignment > 0); assert(alignment > 0);
......
...@@ -11,7 +11,7 @@ namespace migraphx { ...@@ -11,7 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range> template <class Range>
void cse_range(program& p, Range&& r) void cse_range(module& p, Range&& r)
{ {
std::unordered_multimap<std::string, instruction_ref> instructions; std::unordered_multimap<std::string, instruction_ref> instructions;
std::unordered_set<instruction_ref> processed_ins; std::unordered_set<instruction_ref> processed_ins;
...@@ -42,7 +42,7 @@ void cse_range(program& p, Range&& r) ...@@ -42,7 +42,7 @@ void cse_range(program& p, Range&& r)
} }
} }
void eliminate_common_subexpression::apply(program& p) const { cse_range(p, iterator_for(p)); } void eliminate_common_subexpression::apply(module& p) const { cse_range(p, iterator_for(p)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(program& p) const void eliminate_concat::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -65,7 +65,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<instruction ...@@ -65,7 +65,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<instruction
return try_compute_shape(ins, inputs); return try_compute_shape(ins, inputs);
} }
void eliminate_contiguous::apply(program& p) const void eliminate_contiguous::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(program& p) const void eliminate_identity::apply(module& p) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_pad::apply(program& p) const void eliminate_pad::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
...@@ -34,7 +34,7 @@ template <class T> ...@@ -34,7 +34,7 @@ template <class T>
void eliminate_pad::update_op(T, void eliminate_pad::update_op(T,
const instruction_ref& input, const instruction_ref& input,
const instruction_ref& ins, const instruction_ref& ins,
program& p) const module& p) const
{ {
auto pad_op = any_cast<op::pad>(input->get_operator()); auto pad_op = any_cast<op::pad>(input->get_operator());
if(!pad_op.symmetric()) if(!pad_op.symmetric())
...@@ -56,7 +56,7 @@ void eliminate_pad::update_op(T, ...@@ -56,7 +56,7 @@ void eliminate_pad::update_op(T,
void eliminate_pad::update_pooling(const instruction_ref& input, void eliminate_pad::update_pooling(const instruction_ref& input,
const instruction_ref& ins, const instruction_ref& ins,
program& p) const module& p) const
{ {
auto pad_op = any_cast<op::pad>(input->get_operator()); auto pad_op = any_cast<op::pad>(input->get_operator());
if(!pad_op.symmetric()) if(!pad_op.symmetric())
......
...@@ -9,11 +9,12 @@ namespace migraphx { ...@@ -9,11 +9,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
struct auto_contiguous struct auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment