Commit c1d244d9 authored by Paul's avatar Paul
Browse files

Formatting

parent cf8cc835
...@@ -287,9 +287,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a ...@@ -287,9 +287,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
{ {
std::size_t nelements = result.get_shape().elements(); std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, args...)([&](auto output, auto... inputs) { hip_visit_all(result, args...)([&](auto output, auto... inputs) {
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) { output.data()[i] = f(inputs.data()[i]...); });
output.data()[i] = f(inputs.data()[i]...);
});
}); });
} }
......
...@@ -85,7 +85,7 @@ struct hip_array ...@@ -85,7 +85,7 @@ struct hip_array
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y) friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
{ {
hip_array result; hip_array result;
for(std::size_t i = 0;i < N;i++) for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i]; result[i] = x[i] * y[i];
return result; return result;
} }
...@@ -167,15 +167,9 @@ struct hip_shape ...@@ -167,15 +167,9 @@ struct hip_shape
std::copy(s.strides().begin(), s.strides().end(), strides.begin()); std::copy(s.strides().begin(), s.strides().end(), strides.begin());
} }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
{
return lens.product();
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const { return x.dot(strides); }
{
return x.dot(strides);
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
{ {
...@@ -234,8 +228,11 @@ struct hip_tensor_view ...@@ -234,8 +228,11 @@ struct hip_tensor_view
MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }
template<class U> template <class U>
MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const { return d[s.index(i)]; } MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const
{
return d[s.index(i)];
}
MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment