Commit 88f549e2 authored by Paul's avatar Paul
Browse files

Fix output arg

parent b7aa8f2a
...@@ -29,7 +29,7 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>& ...@@ -29,7 +29,7 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
std::transform( std::transform(
args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); }); args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
k.launch(ctx.get_stream().get(), global, local, std::move(kargs)); k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
return args.back(); return args[get_output_arg(args.size())];
} }
void code_object_op::finalize(context&, const shape&, const std::vector<shape>&) void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
{ {
......
...@@ -21,6 +21,7 @@ struct code_object_op ...@@ -21,6 +21,7 @@ struct code_object_op
std::size_t local; std::size_t local;
std::vector<shape> expected_inputs; std::vector<shape> expected_inputs;
shape output; shape output;
std::int64_t output_arg = -1;
kernel k{}; kernel k{};
template <class Self, class F> template <class Self, class F>
...@@ -39,9 +40,13 @@ struct code_object_op ...@@ -39,9 +40,13 @@ struct code_object_op
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void finalize(context&, const shape&, const std::vector<shape>&); void finalize(context&, const shape&, const std::vector<shape>&);
std::int64_t get_output_arg(std::size_t n) const
{
return output_arg < 0 ? n + output_arg : output_arg;
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return get_output_arg(shapes.size());
} }
friend std::ostream& operator<<(std::ostream& os, const code_object_op& op) friend std::ostream& operator<<(std::ostream& os, const code_object_op& op)
......
...@@ -534,10 +534,12 @@ instruction_ref insert_mlir(module& m, ...@@ -534,10 +534,12 @@ instruction_ref insert_mlir(module& m,
return lit; return lit;
}; };
std::size_t last = 0;
for(auto input : inputs) for(auto input : inputs)
{ {
const size_t offset = 0; const size_t offset = 0;
auto s = input->get_shape(); auto s = input->get_shape();
last = refs.size();
refs.push_back(input); refs.push_back(input);
refs.push_back(input); refs.push_back(input);
refs.push_back(get_literal(offset)); // offset refs.push_back(get_literal(offset)); // offset
...@@ -558,6 +560,7 @@ instruction_ref insert_mlir(module& m, ...@@ -558,6 +560,7 @@ instruction_ref insert_mlir(module& m,
} }
co.expected_inputs = to_shapes(refs); co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front(); co.output = mmlir.get_output_shapes().front();
co.output_arg = last;
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
......
...@@ -44,6 +44,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir) ...@@ -44,6 +44,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
}); });
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front())); inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::insert_mlir(*mm, mm->end(), mmlir, inputs); migraphx::gpu::insert_mlir(*mm, mm->end(), mmlir, inputs);
std::cout << p << std::endl;
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment