Commit 3de56715 authored by wsttiger's avatar wsttiger
Browse files

Formatting

parent 447eec93
......@@ -8,14 +8,14 @@
namespace migraph {
namespace gpu {
shape hip_concat::compute_shape(std::vector<shape> inputs) const
shape hip_concat::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.compute_shape(inputs);
}
std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
......@@ -29,8 +29,8 @@ std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
}
argument hip_concat::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
const shape& output_shape,
const std::vector<argument>& args) const
{
std::vector<std::size_t> offsets = compute_offsets(output_shape, args);
return device::concat(output_shape, args, offsets);
......
......@@ -8,30 +8,30 @@ namespace migraph {
namespace gpu {
namespace device {
argument concat(const migraph::shape& output_shape, std::vector<migraph::argument> args, std::vector<std::size_t> offsets)
argument concat(const migraph::shape& output_shape,
std::vector<migraph::argument> args,
std::vector<std::size_t> offsets)
{
//migraph::argument& result = args.back();
for(std::size_t l = 0; l < args.size()-1; l++)
// migraph::argument& result = args.back();
for(std::size_t l = 0; l < args.size() - 1; l++)
{
auto argl = args[l];
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(args.back(), argl)([&](auto output, auto input) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto* outptr = output.data() + offsets[l];
const auto* inptr = input.data();
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(nelements)([=](auto i) {
outptr[desc_output.linear(desc_input.multi(i))] = inptr[i];
});
gs_launch(nelements)(
[=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
});
});
}
//return result;
// return result;
return args.back();
}
} // namespace device
} // namespace gpu
} // namespace migraph
......@@ -5,7 +5,8 @@ namespace migraph {
namespace gpu {
namespace device {
argument concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
argument
concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
} // namespace device
} // namespace gpu
......
......@@ -165,8 +165,8 @@ struct miopen_apply
instruction_ref apply_concat(instruction_ref ins)
{
auto&& op = any_cast<op::concat>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
auto&& op = any_cast<op::concat>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return prog->replace_instruction(ins, hip_concat{op}, refs);
......
......@@ -523,7 +523,7 @@ struct test_concat
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 1;
std::size_t axis = 1;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {2, 3}};
migraph::shape s2{migraph::shape::int32_type, {2, 1}};
......@@ -540,7 +540,7 @@ struct test_concat2
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 0;
std::size_t axis = 0;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment