Commit 399eacef authored by Paul's avatar Paul
Browse files

Improve making the args

parent 8423d683
......@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& name)
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{
params.push_back({name, "T" + name});
tparams.push_back("class T" + name);
params.push_back({pname, "T" + pname});
tparams.push_back("class T" + pname);
return *this;
}
......
......@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm,
}
static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
find_inputs(module_ref sm, const module& parent, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
......@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
continue;
if(param->name() != "@param")
continue;
if(not parent.has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
......@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(result.size() == sm->get_parameter_shapes().size());
return result;
}
......@@ -211,7 +214,7 @@ struct find_pointwise_reduce
// Insert fused_reduce
insert_module_in_submodule(rm, reduce, map_ins);
auto new_inputs = find_inputs(rm, map_ins);
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
}
};
......@@ -266,7 +269,7 @@ struct find_reduce_pointwise
auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out);
auto new_inputs = find_inputs(rm, map_ins);
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
......@@ -300,7 +303,7 @@ struct find_reduce_reduce
auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out);
auto new_inputs = find_inputs(rm, map_ins);
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
}
};
......
......@@ -77,7 +77,7 @@ struct cpp_generator
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
function& add_generic_param(const std::string& name);
function& add_generic_param(const std::string& pname);
};
cpp_generator();
......
......@@ -180,7 +180,7 @@ struct index
}
else
{
static_assert(max_stride_iterations(n, stride) < 64);
// static_assert(max_stride_iterations(n, stride) < 64);
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment