Commit b171e9ad authored by Paul
Browse files

Refactor pad

parent 1599d553
......@@ -13,6 +13,7 @@ template <class T, std::size_t N>
struct hip_tensor_view
{
using value_type = T;
using hip_index = typename hip_shape<N>::hip_index;
__device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
......
......@@ -15,34 +15,26 @@ argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
    // Writes a padded copy of `arg1` into `result` and returns `result`.
    //
    //   stream : stream handed to both gs_launch calls below.
    //   result : destination; every element is first set to the pad value.
    //   arg1   : source; each of its elements is then stored into `result`
    //            at its own multi-index shifted by the leading entries of `pads`.
    //   value  : pad value, converted to the element type of `result`.
    //   pads   : per-dimension offsets; only the first offsets.size() entries
    //            are read here.
    //
    // NOTE(review): this body is the post-refactor side of the diff hunk
    // (single hip_visit_all, no hip_tensor_descriptor); confirm against the
    // repository before relying on it.
    std::size_t nelements = arg1.get_shape().elements();
    hip_visit_all(result, arg1)([&](auto output, auto input) {
        using type      = typename decltype(output)::value_type;
        using hip_index = typename decltype(output)::hip_index;
        type device_val = value;
        // A pad value equal to the lowest float is replaced by the lowest
        // value of the element type rather than the converted float.
        // NOTE(review): presumably a "minus infinity" sentinel from callers - confirm.
        if(float_equal(value, std::numeric_limits<float>::lowest()))
        {
            device_val = device_cast(std::numeric_limits<type>::lowest());
        }
        // First launch: one index per element of `result`, each set to the pad value.
        gs_launch(stream, result.get_shape().elements())(
            [=](auto i) { output.data()[i] = device_val; });

        // Leading entries of `pads` become the per-dimension start offsets.
        // NOTE(review): assumes pads.size() >= offsets.size(); otherwise the copy
        // reads past the end of `pads` - TODO confirm callers guarantee this.
        hip_index offsets;
        std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
        // Second launch: one index per element of `arg1`; the lambda captures
        // `input`, `output` and `offsets` by value.
        // NOTE(review): relies on the second launch observing the fill from the
        // first; presumably guaranteed by using the same stream - confirm.
        gs_launch(stream, nelements)([=](auto i) {
            auto idx = input.get_shape().multi(i);
            for(std::size_t j = 0; j < offsets.size(); j++)
            {
                idx[j] += offsets[j];
            }
            output[idx] = input.data()[i];
        });
    });
    return result;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment