pad.cpp 2.39 KB
Newer Older
1
2
3
4
5
6
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
Khalique's avatar
Khalique committed
7
#include <migraphx/float_equal.hpp>
8
9
10
11
12
13

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Khalique's avatar
Khalique committed
14
15
/// Pads `arg1` into `result` on the GPU.
///
/// `result` is first filled through `nary` with a constant, then every element
/// of `arg1` is copied into `result` at its own multi-index shifted by the
/// leading per-dimension offsets taken from `pads`.
///
/// @param stream  HIP stream that both launches are issued on.
/// @param result  Output argument; it is written to and then returned.
/// @param arg1    Input argument whose elements are copied into `result`.
/// @param value   Constant used for the fill. A value that compares
///                `float_equal` to the lowest finite float is replaced by
///                exactly that lowest float.
/// @param pads    Pad sizes. Only the first `ndim` entries are read (`ndim`
///                being the number of dimensions of `result`), so the vector
///                must hold at least that many entries.
/// @return `result`.
argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
    std::size_t nelements = arg1.get_shape().elements();

    // Choose the fill constant once so that a single launch covers both the
    // "lowest float" sentinel and the ordinary case.
    // NOTE(review): the sentinel is the lowest *float*; how it converts when the
    // output element type is not float is not visible here - confirm.
    const float lowest_float = std::numeric_limits<float>::lowest();
    auto fill_value = device_cast(float_equal(value, lowest_float) ? lowest_float : value);

    // Presumably writes `fill_value` to every element of `result`; `nary` is
    // defined elsewhere, so verify against its implementation.
    nary(stream, result)([=] { return fill_value; });

    visit_all(result, arg1)([&](auto output, auto input) {
        visit_tensor_size(result.get_shape().lens().size(), [&](auto ndim) {
            // Leading pad amount for each dimension, copied into a plain array
            // so the launch lambda below can capture it by value.
            std::size_t offsets[ndim];
            std::copy(pads.begin(), pads.begin() + ndim, offsets);
            auto* outptr      = output.data();
            const auto* inptr = input.data();
            hip_tensor_descriptor<ndim> desc_input(input.get_shape());
            hip_tensor_descriptor<ndim> desc_output(output.get_shape());
            // One work item per input element: translate the input multi-index
            // by the offsets and store at the matching output position.
            gs_launch(stream, nelements)([=](auto i) {
                auto idx = desc_input.multi(i);
                for(std::size_t j = 0; j < ndim; j++)
                {
                    idx[j] += offsets[j];
                }
                outptr[desc_output.linear(idx)] = inptr[i];
            });
        });
    });
    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx