Commit b092d017 authored by Paul's avatar Paul
Browse files

Formatting

parent eed607a3
...@@ -16,18 +16,23 @@ argument concat(hipStream_t stream, ...@@ -16,18 +16,23 @@ argument concat(hipStream_t stream,
std::vector<std::size_t> offsets_vec) std::vector<std::size_t> offsets_vec)
{ {
static constexpr const std::size_t limit = 6; static constexpr const std::size_t limit = 6;
if (offsets_vec.size() > limit) if(offsets_vec.size() > limit)
MIGRAPHX_THROW("Too many arguments to concat"); MIGRAPHX_THROW("Too many arguments to concat");
std::size_t nelements = std::max_element(args_vec.begin(), std::prev(args_vec.end()), by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))->get_shape().elements(); std::size_t nelements =
std::max_element(args_vec.begin(),
std::prev(args_vec.end()),
by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))
->get_shape()
.elements();
auto offsets = to_hip_vector<limit>(offsets_vec); auto offsets = to_hip_vector<limit>(offsets_vec);
hip_visit_all<limit+1>(args_vec)([&](auto args) { hip_visit_all<limit + 1>(args_vec)([&](auto args) {
auto output = args.back(); auto output = args.back();
auto ninputs = args.size() - 1; auto ninputs = args.size() - 1;
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
for(std::size_t j = 0;j < ninputs;j++) for(std::size_t j = 0; j < ninputs; j++)
{ {
auto&& arg = args[j]; auto&& arg = args[j];
if (i >= arg.size()) if(i >= arg.size())
continue; continue;
auto idx = output.get_shape().index(arg.get_shape().multi(i)); auto idx = output.get_shape().index(arg.get_shape().multi(i));
output.data()[idx + offsets[j]] = arg.data()[i]; output.data()[idx + offsets[j]] = arg.data()[i];
......
...@@ -124,8 +124,7 @@ hip_vector<T, N> to_hip_vector(const std::vector<T>& x) ...@@ -124,8 +124,7 @@ hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
return result; return result;
} }
template <std::size_t N>
template<std::size_t N>
struct hip_shape struct hip_shape
{ {
using hip_index = hip_array<std::size_t, N>; using hip_index = hip_array<std::size_t, N>;
...@@ -136,9 +135,7 @@ struct hip_shape ...@@ -136,9 +135,7 @@ struct hip_shape
__device__ __host__ hip_shape() = default; __device__ __host__ hip_shape() = default;
hip_shape(const shape& s) hip_shape(const shape& s) : elements(s.elements()), standard(s.standard())
: elements(s.elements()),
standard(s.standard())
{ {
assert(s.lens().size() == N); assert(s.lens().size() == N);
assert(s.strides().size() == N); assert(s.strides().size() == N);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment