mul_add.cpp 3.1 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/device/launch.hpp>
2
#include <migraphx/gpu/device/mul_add.hpp>
Paul's avatar
Paul committed
3
#include <migraphx/gpu/device/nary.hpp>
4
5
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
Paul's avatar
Paul committed
6
#include <hip/math_functions.h>

#include <cstddef>
#include <limits>
Paul's avatar
Paul committed
7
8
9
10
11
12

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

13
// Packed fp16 fused multiply-add for the 2-D case, two values per thread.
// All buffers are reinterpreted as __half2 and the kernel computes
//   r[i] = a[i] * x[i % dim3] + b[i % dim3]      for 0 <= i < n
// where
//   n    - number of __half2 elements in `a` and `r`,
//   dim3 - number of __half2 elements in one row; `x` and `b` each hold a
//          single row that is reused for every row of `a`.
// Launch with a 1-D grid covering at least `n` threads; surplus threads exit.
// NOTE(review): assumes the buffers are __half2-aligned (not checked here).
__global__ void
mul_add_kernel_dim3(const void* a, const void* x, const void* b, int dim3, void* r, int n)
{
    // 64-bit global index: when n is close to INT_MAX the rounded-up grid
    // launches tail threads whose index would overflow a signed int and pass
    // the bounds check with a negative value.
    const std::size_t id = static_cast<std::size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
    if(n <= 0 or dim3 <= 0 or id >= static_cast<std::size_t>(n))
        return;

    const auto* ha = reinterpret_cast<const __half2*>(a);
    const auto* hx = reinterpret_cast<const __half2*>(x);
    const auto* hb = reinterpret_cast<const __half2*>(b);
    auto* hr       = reinterpret_cast<__half2*>(r);

    // Column within the row: x and b are broadcast across rows.
    const std::size_t col = id % static_cast<std::size_t>(dim3);
    hr[id]                = __hfma2(ha[id], hx[col], hb[col]);
}
Shucai Xiao's avatar
Shucai Xiao committed
26

27
28
29
// Packed fp16 fused multiply-add for the 3-D case, two values per thread.
// All buffers are reinterpreted as __half2. Viewing `a`, `x` and `r` as
// [outer, factor, dim4] (in __half2 units) and `b` as [outer, dim4], the
// kernel computes
//   r[i] = a[i] * x[i] + b[(i / (factor * dim4)) * dim4 + i % dim4]
// for 0 <= i < n, i.e. `b` is broadcast over the middle (`factor`) axis.
//   n      - number of __half2 elements in `a`, `x` and `r`,
//   factor - length of the middle axis,
//   dim4   - number of __half2 elements in the last axis.
// Launch with a 1-D grid covering at least `n` threads; surplus threads exit.
// NOTE(review): assumes the buffers are __half2-aligned (not checked here).
__global__ void mul_add_kernel_dim4(
    const void* a, const void* x, const void* b, int factor, int dim4, void* r, int n)
{
    // 64-bit global index so tail threads of a grid sized for n near INT_MAX
    // cannot wrap to a negative index that slips past the bounds check.
    const std::size_t id = static_cast<std::size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
    if(n <= 0 or factor <= 0 or dim4 <= 0 or id >= static_cast<std::size_t>(n))
        return;

    const auto* ha = reinterpret_cast<const __half2*>(a);
    const auto* hx = reinterpret_cast<const __half2*>(x);
    const auto* hb = reinterpret_cast<const __half2*>(b);
    auto* hr       = reinterpret_cast<__half2*>(r);

    // Drop the middle-axis coordinate to find the matching element of b;
    // the product factor * dim4 is formed in 64 bits to avoid int overflow.
    const std::size_t inner = static_cast<std::size_t>(dim4);
    const std::size_t slab  = static_cast<std::size_t>(factor) * inner;
    const std::size_t idb   = id / slab * inner + id % inner;
    hr[id]                  = __hfma2(ha[id], hx[id], hb[idb]);
}

41
static bool is_bert(const std::vector<shape>& ss)
Shucai Xiao's avatar
Shucai Xiao committed
42
{
43
    auto last_dim = ss.front().lens().back();
Shucai Xiao's avatar
Shucai Xiao committed
44
    if(last_dim % 2 != 0)
45
46
47
48
    {
        return false;
    }

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
    auto n_dim = ss.front().lens().size();
    if(n_dim == 3)
    {
        auto stride = ss.at(2).strides();
        return (stride[1] == 0);
    }
    else if(n_dim == 2)
    {
        auto stride1 = ss.at(1).strides();
        auto stride2 = ss.at(2).strides();
        return (stride1 == stride2 and stride1[0] == 0);
    }

    return false;
}
Shucai Xiao's avatar
Shucai Xiao committed
64

Paul's avatar
Paul committed
65
// Elementwise fused multiply-add enqueued on `stream`; the result equals
// arg1 * arg2 + arg3. The generic path's lambda evaluates `a * x + b` with its
// parameters presumably bound to (arg1, arg2, arg3) in order by nary; the fast
// path evaluates arg1 * arg2 + arg3 with __hfma2. Both give the same value
// because multiplication is commutative.
//
// Fast path: fp16 tensors of rank 2 or 3 whose input layouts are accepted by
// is_bert and whose result is packed row-major are processed two values per
// thread as __half2 by mul_add_kernel_dim3 (rank 2) or mul_add_kernel_dim4
// (rank 3). Everything else goes through the generic nary implementation.
void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    const auto sr               = result.get_shape();
    const auto& lens            = sr.lens();
    const std::vector<shape> ss = {arg1.get_shape(), arg2.get_shape(), arg3.get_shape()};

    // Number of __half2 elements the fast-path kernels would process.
    const std::size_t elem_num = sr.elements() / 2;

    // The kernels store hr[id] linearly, so the result must be packed
    // row-major; they take the element count as `int`, so it must fit; and a
    // zero-element tensor would mean launching an empty grid.
    const bool use_half2 =
        sr.type() == shape::half_type and (lens.size() == 2 or lens.size() == 3) and
        sr.standard() and elem_num > 0 and
        elem_num <= static_cast<std::size_t>(std::numeric_limits<int>::max()) and is_bert(ss);

    if(not use_half2)
    {
        nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
                                                   __device__ { return a * x + b; });
        return;
    }

    constexpr int block_size = 1024;
    const int n              = static_cast<int>(elem_num);
    // Rounded-up block count computed in size_t so n near INT_MAX cannot overflow.
    const int block_num = static_cast<int>((elem_num + block_size - 1) / block_size);
    // Number of __half2 values in one row of the last dimension.
    const int last_dim = static_cast<int>(lens.back() / 2);

    if(lens.size() == 2)
    {
        mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), n);
    }
    else
    {
        // Rank 3: arg3 is broadcast over the middle axis of length lens[1].
        const int factor = static_cast<int>(lens[1]);
        mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), n);
    }
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx