"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e4759983488dc82cc8254a7668030375a5053bf9"
Commit 399eacef authored by Paul's avatar Paul
Browse files

Improve making the args

parent 8423d683
...@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& name) cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{ {
params.push_back({name, "T" + name}); params.push_back({pname, "T" + pname});
tparams.push_back("class T" + name); tparams.push_back("class T" + pname);
return *this; return *this;
} }
......
...@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm, ...@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm,
} }
static std::vector<instruction_ref> static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins) find_inputs(module_ref sm, const module& parent, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names; std::map<std::string, instruction_ref> names;
...@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction ...@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
continue; continue;
if(param->name() != "@param") if(param->name() != "@param")
continue; continue;
if(not parent.has_instruction(input))
continue;
auto v = param->get_operator().to_value(); auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>(); auto name = v.at("parameter").to<std::string>();
names[name] = input; names[name] = input;
...@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction ...@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) { std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second; return p.second;
}); });
assert(result.size() == sm->get_parameter_shapes().size());
return result; return result;
} }
...@@ -211,7 +214,7 @@ struct find_pointwise_reduce ...@@ -211,7 +214,7 @@ struct find_pointwise_reduce
// Insert fused_reduce // Insert fused_reduce
insert_module_in_submodule(rm, reduce, map_ins); insert_module_in_submodule(rm, reduce, map_ins);
auto new_inputs = find_inputs(rm, map_ins); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
} }
}; };
...@@ -266,7 +269,7 @@ struct find_reduce_pointwise ...@@ -266,7 +269,7 @@ struct find_reduce_pointwise
auto out = insert_ins_in_submodule(rm, pw, map_ins); auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out); rm->replace_return(out);
auto new_inputs = find_inputs(rm, map_ins); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
} }
}; };
...@@ -300,7 +303,7 @@ struct find_reduce_reduce ...@@ -300,7 +303,7 @@ struct find_reduce_reduce
auto out = insert_module_in_submodule(rm, reduce1, map_ins); auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out); rm->replace_return(out);
auto new_inputs = find_inputs(rm, map_ins); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
} }
}; };
......
...@@ -77,7 +77,7 @@ struct cpp_generator ...@@ -77,7 +77,7 @@ struct cpp_generator
function& set_types(const module& m); function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse); function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m); function& set_generic_types(const module& m);
function& add_generic_param(const std::string& name); function& add_generic_param(const std::string& pname);
}; };
cpp_generator(); cpp_generator();
......
...@@ -180,7 +180,7 @@ struct index ...@@ -180,7 +180,7 @@ struct index
} }
else else
{ {
static_assert(max_stride_iterations(n, stride) < 64); // static_assert(max_stride_iterations(n, stride) < 64);
sequence(max_stride_iterations(n, stride), [&](auto... ks) { sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) { fold([&](auto d, auto k) {
auto i = start + stride * k; auto i = start + stride * k;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment