"src/targets/vscode:/vscode.git/clone" did not exist on "d0e2ace63ee1054ed55256757f592a0d3a54b295"
Commit 88f549e2 authored by Paul's avatar Paul
Browse files

Fix output arg

parent b7aa8f2a
......@@ -29,7 +29,7 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
std::transform(
args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
return args.back();
return args[get_output_arg(args.size())];
}
void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
{
......
......@@ -21,6 +21,7 @@ struct code_object_op
std::size_t local;
std::vector<shape> expected_inputs;
shape output;
std::int64_t output_arg = -1;
kernel k{};
template <class Self, class F>
......@@ -39,9 +40,13 @@ struct code_object_op
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void finalize(context&, const shape&, const std::vector<shape>&);
std::int64_t get_output_arg(std::size_t n) const
{
return output_arg < 0 ? n + output_arg : output_arg;
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
return get_output_arg(shapes.size());
}
friend std::ostream& operator<<(std::ostream& os, const code_object_op& op)
......
......@@ -534,10 +534,12 @@ instruction_ref insert_mlir(module& m,
return lit;
};
std::size_t last = 0;
for(auto input : inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
last = refs.size();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
......@@ -558,6 +560,7 @@ instruction_ref insert_mlir(module& m,
}
co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front();
co.output_arg = last;
return m.insert_instruction(ins, co, refs);
}
......
......@@ -44,6 +44,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
});
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::insert_mlir(*mm, mm->end(), mmlir, inputs);
std::cout << p << std::endl;
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment