Commit 46653e67 authored by Paul's avatar Paul
Browse files

Formatting

parent b171e9ad
...@@ -13,7 +13,7 @@ template <class T, std::size_t N> ...@@ -13,7 +13,7 @@ template <class T, std::size_t N>
struct hip_tensor_view struct hip_tensor_view
{ {
using value_type = T; using value_type = T;
using hip_index = typename hip_shape<N>::hip_index; using hip_index = typename hip_shape<N>::hip_index;
__device__ __host__ hip_tensor_view() = default; __device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {} __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {} __host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
......
...@@ -16,14 +16,15 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -16,14 +16,15 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
{ {
std::size_t nelements = arg1.get_shape().elements(); std::size_t nelements = arg1.get_shape().elements();
hip_visit_all(result, arg1)([&](auto output, auto input) { hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index; using hip_index = typename decltype(output)::hip_index;
type device_val = value; type device_val = value;
if(float_equal(value, std::numeric_limits<float>::lowest())) if(float_equal(value, std::numeric_limits<float>::lowest()))
{ {
device_val = device_cast(std::numeric_limits<type>::lowest()); device_val = device_cast(std::numeric_limits<type>::lowest());
} }
gs_launch(stream, result.get_shape().elements())([=](auto i) { output.data()[i] = device_val; }); gs_launch(stream,
result.get_shape().elements())([=](auto i) { output.data()[i] = device_val; });
hip_index offsets; hip_index offsets;
std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin()); std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
...@@ -34,7 +35,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -34,7 +35,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
idx[j] += offsets[j]; idx[j] += offsets[j];
} }
output[idx] = input.data()[i]; output[idx] = input.data()[i];
}); });
}); });
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment