hip.hpp 1.66 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>
#include <migraph/operators.hpp>
Paul's avatar
Paul committed
5

Paul's avatar
Paul committed
6
namespace migraph {
Paul's avatar
Paul committed
7
namespace gpu {
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
// Returns an argument for shape `s`; hip_allocate::compute returns this
// directly as its result.
// NOTE(review): presumably allocates a GPU buffer of s.bytes() bytes —
// confirm against the definition.
migraph::argument allocate_gpu(migraph::shape s);
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
// Converts `arg` into the argument returned by hip_write::compute.
// NOTE(review): presumably copies host data in `arg` into a new GPU buffer —
// confirm against the definition.
migraph::argument to_gpu(migraph::argument arg);
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
// NOTE(review): by its name, presumably the inverse of to_gpu (copies GPU
// data in `arg` back to a host argument); not used in this header — confirm
// against the definition.
migraph::argument from_gpu(migraph::argument arg);
Paul's avatar
Paul committed
14

mei-ye's avatar
mei-ye committed
15
16
// Copies `size` bytes from `src` to `dst`. hip_memcpy::compute calls it with
// `dst` pointing into its destination argument's data.
// NOTE(review): presumably `dst` is device memory and `src` may be host or
// device memory — confirm against the definition.
void copy_to_gpu(char* dst, const char* src, std::size_t size);

Paul's avatar
Paul committed
17
18
19
20
21
22
23
24
struct hip_allocate
{
    std::string name() const { return "hip::allocate"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
        return inputs.front();
    }
Paul's avatar
Paul committed
25
    argument compute(context&, shape output_shape, std::vector<argument>) const
Paul's avatar
Paul committed
26
    {
Paul's avatar
Paul committed
27
        return allocate_gpu(output_shape);
Paul's avatar
Paul committed
28
29
30
    }
};

Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/// Operator "hip::write".
///
/// The output shape is the single input shape, passed through unchanged.
/// compute() forwards its first runtime argument to to_gpu() and returns
/// the result.
struct hip_write
{
    std::string name() const { return "hip::write"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // Exactly one input shape is accepted.
        check_shapes{inputs}.has(1);
        shape passthrough = inputs.front();
        return passthrough;
    }

    argument compute(context&, shape, std::vector<argument> args) const
    {
        const argument& source = args.front();
        return to_gpu(source);
    }
};

mei-ye's avatar
mei-ye committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/// Operator "hip_memcpy".
///
/// Copies the bytes of input 2 into the buffer of input 0, starting at a byte
/// offset stored in input 1. Returns an argument with input 2's shape that
/// points at the written region inside input 0; no new buffer is created.
///
/// Inputs, by position:
///   0 - destination buffer
///   1 - buffer whose leading std::size_t is the byte offset into input 0
///   2 - source buffer; its shape is the output shape, and its byte size is
///       the number of bytes copied
///
/// Throws std::invalid_argument if input 1 is too small to hold the offset,
/// and std::out_of_range if the copy would run past the end of input 0.
///
/// NOTE(review): the offset is read on the host through args.at(1).data(), so
/// input 1 presumably must be host-accessible memory — confirm with callers.
/// NOTE(review): name() is "hip_memcpy" while the sibling operators use a
/// "hip::" prefix; it is kept as-is because other code may match on it.
struct hip_memcpy
{
    std::string name() const { return "hip_memcpy"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // The output has the source's shape. at(2) throws std::out_of_range
        // when fewer than three inputs are given.
        return inputs.at(2);
    }

    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        // Make sure the offset buffer really holds a full std::size_t before
        // reading it.
        if(args.at(1).get_shape().bytes() < sizeof(std::size_t))
            throw std::invalid_argument("hip_memcpy: offset buffer is smaller than size_t");

        // Read the offset with memcpy instead of dereferencing a
        // reinterpret_cast'ed pointer. The buffer is raw bytes with no
        // guaranteed size_t alignment, and the cast-and-dereference also
        // violates strict aliasing.
        std::size_t offset = 0;
        std::memcpy(&offset, args.at(1).data(), sizeof(offset));

        const char* src             = args.at(2).data();
        const std::size_t size      = args.at(2).get_shape().bytes();
        const std::size_t dst_bytes = args.at(0).get_shape().bytes();

        // Refuse to write past the destination buffer. Comparing against
        // (dst_bytes - offset) cannot overflow, unlike offset + size.
        if(offset > dst_bytes || size > dst_bytes - offset)
            throw std::out_of_range("hip_memcpy: copy of " + std::to_string(size) +
                                    " bytes at offset " + std::to_string(offset) +
                                    " exceeds destination of " + std::to_string(dst_bytes) +
                                    " bytes");

        char* dst = args.at(0).data() + offset;
        copy_to_gpu(dst, src, size);
        return {output_shape, dst};
    }
};

Paul's avatar
Paul committed
62
} // namespace gpu
Paul's avatar
Paul committed
63

Paul's avatar
Paul committed
64
} // namespace migraph
Paul's avatar
Paul committed
65
66

#endif