Commit fb706c81 authored by Paul's avatar Paul
Browse files

Fix compile errors

parent 5ee35bf0
......@@ -95,13 +95,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
return param{p.first, "T" + to_c_id(p.first)};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
[&](auto&& p) { return "class T" + to_c_id(p.first); });
this->return_type = "auto";
return *this;
}
......@@ -200,12 +200,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g)
{
function f;
auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
f.set_name(to_c_id(m.name())).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
{
......@@ -241,7 +236,7 @@ cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
{
std::vector<std::string> args;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
return names.at(i);
return to_c_id(names.at(i));
});
return args;
}
......@@ -265,7 +260,7 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
impl->fs << delim;
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
impl->fs << delim << p.type << " " << to_c_id(p.name);
delim = ',';
}
impl->fs << ") {\n" << f.body << "\n}\n";
......
......@@ -37,6 +37,11 @@ struct identity
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape, std::vector<argument> args) const { return args[0]; }
value attributes() const
{
return {{"pointwise", true}, {"point_op", "${0}"}};
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
if(contains({"multibroadcast", "contiguous"}, ins.name()))
if(contains({"multibroadcast", "contiguous", "identity"}, ins.name()))
continue;
if(ins.name() == "pointwise")
{
......
......@@ -126,6 +126,11 @@ struct compile_plan
{
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
}
catch(const std::exception& e)
{
std::cerr << "Exception in " + preop.name() + ": " + e.what() << std::endl;
results[i] = nullopt;
}
catch(...)
{
results[i] = nullopt;
......
......@@ -145,6 +145,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
std::cout << "**************** fused_concat_compiler" << std::endl;
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
......@@ -206,25 +207,33 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
ins->module_inputs().end() - 1,
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
ins->module_inputs().end()-1,
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["kernel"] = transform_accumulate(
auto prefix_name = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) +
[&](module_ref mod) -> std::string {
auto name = generate_name_from_ops(*mod);
if(name.empty())
return "";
return name + "_";
});
v["kernel"] = prefix_name +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v);
......
......@@ -49,7 +49,7 @@ constexpr auto concat_slice(Output out, Input, Start)
template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs)
{
return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
return [=](auto f) { return f(concat_slice<Axis>(xs, input, start)...); };
}
template <index_int Axis, class Input>
......@@ -81,7 +81,7 @@ __device__ auto concat2(InputPacks... input_packs)
auto idx = make_index();
fold([&](auto start, auto input_pack) {
return input_pack([&](auto g, auto x, auto... xs) {
concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(),
[&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment