Commit fb706c81 authored by Paul's avatar Paul
Browse files

Fix compile errors

parent 5ee35bf0
...@@ -95,13 +95,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -95,13 +95,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
std::map<std::string, shape> input_map(pmap.begin(), pmap.end()); std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform( std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) { input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first}; return param{p.first, "T" + to_c_id(p.first)};
}); });
std::transform(input_map.begin(), std::transform(input_map.begin(),
input_map.end(), input_map.end(),
std::back_inserter(this->tparams), std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; }); [&](auto&& p) { return "class T" + to_c_id(p.first); });
this->return_type = "auto"; this->return_type = "auto";
return *this; return *this;
} }
...@@ -200,12 +200,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m, ...@@ -200,12 +200,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g) const generate_module_callback& g)
{ {
function f; function f;
auto name = transform_string(m.name(), [](char c) { f.set_name(to_c_id(m.name())).set_types(m).set_body(
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string { m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
...@@ -241,7 +236,7 @@ cpp_generator::to_args(const std::vector<instruction_ref>& inputs, ...@@ -241,7 +236,7 @@ cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
{ {
std::vector<std::string> args; std::vector<std::string> args;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) { std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
return names.at(i); return to_c_id(names.at(i));
}); });
return args; return args;
} }
...@@ -265,7 +260,7 @@ std::string cpp_generator::create_function(const cpp_generator::function& f) ...@@ -265,7 +260,7 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
impl->fs << delim; impl->fs << delim;
for(auto&& p : f.params) for(auto&& p : f.params)
{ {
impl->fs << delim << p.type << " " << p.name; impl->fs << delim << p.type << " " << to_c_id(p.name);
delim = ','; delim = ',';
} }
impl->fs << ") {\n" << f.body << "\n}\n"; impl->fs << ") {\n" << f.body << "\n}\n";
......
...@@ -37,6 +37,11 @@ struct identity ...@@ -37,6 +37,11 @@ struct identity
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape, std::vector<argument> args) const { return args[0]; } argument compute(shape, std::vector<argument> args) const { return args[0]; }
value attributes() const
{
return {{"pointwise", true}, {"point_op", "${0}"}};
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
if(contains({"multibroadcast", "contiguous"}, ins.name())) if(contains({"multibroadcast", "contiguous", "identity"}, ins.name()))
continue; continue;
if(ins.name() == "pointwise") if(ins.name() == "pointwise")
{ {
......
...@@ -126,6 +126,11 @@ struct compile_plan ...@@ -126,6 +126,11 @@ struct compile_plan
{ {
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins}; results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
} }
catch(const std::exception& e)
{
std::cerr << "Exception in " + preop.name() + ": " + e.what() << std::endl;
results[i] = nullopt;
}
catch(...) catch(...)
{ {
results[i] = nullopt; results[i] = nullopt;
......
...@@ -145,6 +145,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -145,6 +145,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
std::cout << "**************** fused_concat_compiler" << std::endl;
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
...@@ -206,25 +207,33 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -206,25 +207,33 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
}); });
std::vector<std::string> mod_names; std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(), ins->module_inputs().end() - 1,
std::back_inserter(mod_names), std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); }); [&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names; v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args; std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(), ins->module_inputs().end()-1,
std::inserter(mod_args, mod_args.end()), std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) { [&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name()); const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size()); return std::make_pair(name, mod->get_parameter_names().size());
}); });
v["args"] = mod_args; v["args"] = mod_args;
v["kernel"] = transform_accumulate( auto prefix_name = transform_accumulate(
ins->module_inputs().begin(), ins->module_inputs().begin(),
ins->module_inputs().end() - 1, ins->module_inputs().end() - 1,
std::string{}, std::string{},
std::plus<>{}, std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) + [&](module_ref mod) -> std::string {
auto name = generate_name_from_ops(*mod);
if(name.empty())
return "";
return name + "_";
});
v["kernel"] = prefix_name +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel"; "_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
......
...@@ -49,7 +49,7 @@ constexpr auto concat_slice(Output out, Input, Start) ...@@ -49,7 +49,7 @@ constexpr auto concat_slice(Output out, Input, Start)
template <index_int Axis, class Input, class Start, class... Ts> template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs) constexpr auto concat_slices(Input input, Start start, Ts... xs)
{ {
return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); }; return [=](auto f) { return f(concat_slice<Axis>(xs, input, start)...); };
} }
template <index_int Axis, class Input> template <index_int Axis, class Input>
...@@ -81,7 +81,7 @@ __device__ auto concat2(InputPacks... input_packs) ...@@ -81,7 +81,7 @@ __device__ auto concat2(InputPacks... input_packs)
auto idx = make_index(); auto idx = make_index();
fold([&](auto start, auto input_pack) { fold([&](auto start, auto input_pack) {
return input_pack([&](auto g, auto x, auto... xs) { return input_pack([&](auto g, auto x, auto... xs) {
concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) { return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(), idx.global_stride(x.get_shape().elements(),
[&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); }); [&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment