// pad.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/float_equal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Khalique's avatar
Khalique committed
15
16
argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
17
18
{
    std::size_t nelements = arg1.get_shape().elements();
Paul's avatar
Paul committed
19
    hip_visit_all(result, arg1)([&](auto output, auto input) {
Paul's avatar
Paul committed
20
21
        using type      = typename decltype(output)::value_type;
        using hip_index = typename decltype(output)::hip_index;
22
        type device_val = pad_clamp<host_type<type>>(value);
23
24
        gs_launch(stream, result.get_shape().elements())(
            [=](auto i) __device__ { output.data()[i] = device_val; });
25

Paul's avatar
Paul committed
26
27
        hip_index offsets;
        std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
28
        gs_launch(stream, nelements)([=](auto i) __device__ {
Paul's avatar
Paul committed
29
30
31
32
33
34
            auto idx = input.get_shape().multi(i);
            for(std::size_t j = 0; j < offsets.size(); j++)
            {
                idx[j] += offsets[j];
            }
            output[idx] = input.data()[i];
Paul's avatar
Paul committed
35
        });
36
37
38
39
40
41
42
43
    });
    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx