Commit 3de56715 authored by wsttiger

Formatting

parent 447eec93
@@ -8,14 +8,14 @@
 namespace migraph {
 namespace gpu {
 shape hip_concat::compute_shape(std::vector<shape> inputs) const
 {
     inputs.pop_back();
     return op.compute_shape(inputs);
 }
 std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
                                                      const std::vector<argument> args) const
 {
     std::vector<std::size_t> offsets;
     std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
@@ -29,8 +29,8 @@ std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
 }
 argument hip_concat::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
 {
     std::vector<std::size_t> offsets = compute_offsets(output_shape, args);
     return device::concat(output_shape, args, offsets);
...
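The body of compute_offsets is elided in this hunk, but its role can be read off its call site in device::concat below: offsets[l] is the linear position, inside the output buffer, of the first element of input l. (apply_concat, further down, appends a preallocated output buffer as the last argument, which is also why compute_shape pops the last input before delegating to op.compute_shape.) A minimal host-side sketch of that offset arithmetic for a concat along one axis, assuming a packed row-major output; the names below are illustrative only, not MIGraphX API, and this is not the actual compute_offsets body:

// Illustration only: derive the per-input element offsets for a concat
// along `axis`, assuming the output is packed row-major.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Inputs {2,2}, {2,3}, {2,1} concatenated along axis 1 -> output {2,6}.
    std::vector<std::vector<std::size_t>> lens = {{2, 2}, {2, 3}, {2, 1}};
    std::size_t axis = 1;

    // Output lengths: same as the inputs except summed along the concat axis.
    std::vector<std::size_t> out_lens = lens[0];
    for(std::size_t l = 1; l < lens.size(); l++)
        out_lens[axis] += lens[l][axis];

    // Row-major strides of the output ({6, 1} here).
    std::vector<std::size_t> strides(out_lens.size(), 1);
    for(std::size_t d = out_lens.size() - 1; d > 0; d--)
        strides[d - 1] = strides[d] * out_lens[d];

    // Each input starts where the previous one ended along the concat axis,
    // so its offset is the accumulated length along that axis times the stride.
    std::size_t along = 0;
    for(const auto& l : lens)
    {
        std::cout << "offset = " << along * strides[axis] << "\n"; // prints 0, 2, 5
        along += l[axis];
    }
}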
@@ -8,30 +8,30 @@ namespace migraph {
 namespace gpu {
 namespace device {
-argument concat(const migraph::shape& output_shape, std::vector<migraph::argument> args, std::vector<std::size_t> offsets)
+argument concat(const migraph::shape& output_shape,
+                std::vector<migraph::argument> args,
+                std::vector<std::size_t> offsets)
 {
-    //migraph::argument& result = args.back();
-    for(std::size_t l = 0; l < args.size()-1; l++)
+    // migraph::argument& result = args.back();
+    for(std::size_t l = 0; l < args.size() - 1; l++)
     {
         auto argl = args[l];
         std::size_t nelements = argl.get_shape().elements();
         visit_all(args.back(), argl)([&](auto output, auto input) {
             visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                 auto* outptr = output.data() + offsets[l];
                 const auto* inptr = input.data();
                 hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                 hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(nelements)([=](auto i) {
-                    outptr[desc_output.linear(desc_input.multi(i))] = inptr[i];
-                });
+                gs_launch(nelements)(
+                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
             });
         });
     }
-    //return result;
+    // return result;
     return args.back();
 }
 } // namespace device
 } // namespace gpu
 } // namespace migraph
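Each loop iteration copies one whole input into its slot of the output: element i of input l lands at output.data() + offsets[l] + desc_output.linear(desc_input.multi(i)), i.e. the input's linear index is expanded to a multi-index with the input's lengths and then relinearized with the output's strides. Below is a CPU stand-in for that mapping in the 2-D, axis-1 case; the index helpers are written out inline and are not hip_tensor_descriptor:

// CPU analogue of the per-element mapping done inside gs_launch:
// out[offset + linear_out(multi_in(i))] = in[i], written out for 2-D tensors.
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::size_t in_rows = 2, in_cols = 3; // input shape {2, 3}
    std::size_t out_cols = 6;             // output shape {2, 6}
    std::size_t offset   = 2;             // this input starts at column 2 of the output

    std::vector<int> in = {1, 2, 3, 4, 5, 6};
    std::vector<int> out(2 * out_cols, 0);

    for(std::size_t i = 0; i < in_rows * in_cols; i++)
    {
        // multi(i) using the input's lengths ...
        std::array<std::size_t, 2> idx = {i / in_cols, i % in_cols};
        // ... then linear(idx) using the output's strides {out_cols, 1}.
        out[offset + idx[0] * out_cols + idx[1]] = in[i];
    }

    for(int v : out)
        std::cout << v << ' '; // 0 0 1 2 3 0  0 0 4 5 6 0
    std::cout << '\n';
}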
@@ -5,7 +5,8 @@ namespace migraph {
 namespace gpu {
 namespace device {
-argument concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
+argument
+concat(const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets);
 } // namespace device
 } // namespace gpu
...
@@ -165,8 +165,8 @@ struct miopen_apply
 instruction_ref apply_concat(instruction_ref ins)
 {
     auto&& op = any_cast<op::concat>(ins->get_operator());
     auto output = insert_allocation(ins, ins->get_shape());
     std::vector<instruction_ref> refs = ins->inputs();
     refs.push_back(output);
     return prog->replace_instruction(ins, hip_concat{op}, refs);
...
@@ -523,7 +523,7 @@ struct test_concat
 migraph::program create_program() const
 {
     migraph::program p;
     std::size_t axis = 1;
     migraph::shape s0{migraph::shape::int32_type, {2, 2}};
     migraph::shape s1{migraph::shape::int32_type, {2, 3}};
     migraph::shape s2{migraph::shape::int32_type, {2, 1}};
@@ -540,7 +540,7 @@ struct test_concat2
 migraph::program create_program() const
 {
     migraph::program p;
     std::size_t axis = 0;
     migraph::shape s0{migraph::shape::int32_type, {2, 2}};
     migraph::shape s1{migraph::shape::int32_type, {3, 2}};
     migraph::shape s2{migraph::shape::int32_type, {1, 2}};
...
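As a sanity check on these test shapes: concatenating {2,2}, {2,3}, and {2,1} along axis 1 should yield a {2,6} result, while concatenating {2,2}, {3,2}, and {1,2} along axis 0 should yield {6,2}.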