mul_add.cpp 3.16 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
4
#include <migraphx/gpu/device/mul_add.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/gpu/device/nary.hpp>
6
7
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
Paul's avatar
Paul committed
8
9
10
11
12
13

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

14
// Elementwise fused multiply-add on packed __half2 values where `x` and `b`
// hold a single row that is reused for every row of `a`:
//     r[i] = a[i] * x[i % row_len] + b[i % row_len]
//
// Launch: 1-D grid of 1-D blocks providing at least `n` threads in total;
// threads whose global index is >= n do nothing.
//
// Parameters:
//   a       - input, read at the linear index i (n __half2 values)
//   x       - input, read at i % row_len
//   b       - input, read at i % row_len
//   row_len - length of one row, counted in __half2 elements; must be > 0
//             (it is used as a modulus)
//   r       - output, written at the linear index i
//   n       - number of __half2 elements to process
//
// The row-length parameter deliberately avoids the name `dim3`, which would
// shadow HIP's built-in launch-configuration type inside this kernel.
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int row_len, void* r, int n)
{
    const int id = blockDim.x * blockIdx.x + threadIdx.x;
    if(id >= n)
        return;

    // Inputs are only read, so view them through const pointers.
    const __half2* ha = reinterpret_cast<const __half2*>(a);
    const __half2* hx = reinterpret_cast<const __half2*>(x);
    const __half2* hb = reinterpret_cast<const __half2*>(b);
    __half2* hr       = reinterpret_cast<__half2*>(r);

    // Position within the row shared by x and b.
    const int col = id % row_len;
    hr[id]        = __hfma2(ha[id], hx[col], hb[col]);
}
Shucai Xiao's avatar
Shucai Xiao committed
27

28
29
30
// Elementwise fused multiply-add on packed __half2 values where `a` and `x`
// are read at the linear index and `b` is read through a broadcast index that
// skips the middle `factor` axis:
//     r[i] = a[i] * x[i] + b[(i / (factor * dim4)) * dim4 + i % dim4]
// As launched from mul_add, `factor` is lens[1] of the result and `dim4` is the
// innermost dimension divided by 2 (i.e. counted in __half2 elements).
//
// Launch: 1-D grid of 1-D blocks providing at least `n` threads in total;
// threads whose global index is >= n do nothing.
// Requires factor * dim4 != 0 and dim4 != 0 (both are used as divisors).
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid >= n)
        return;

    const __half2* lhs  = static_cast<const __half2*>(a);
    const __half2* rhs  = static_cast<const __half2*>(x);
    const __half2* bias = static_cast<const __half2*>(b);
    __half2* out        = static_cast<__half2*>(r);

    // Split the linear index into the outer-slab index and the position
    // within the innermost dimension; the middle axis is dropped for `b`.
    const int slab     = tid / (factor * dim4);
    const int lane     = tid % dim4;
    const int bias_idx = slab * dim4 + lane;

    out[tid] = __hfma2(lhs[tid], rhs[tid], bias[bias_idx]);
}

42
// Decides whether the input shapes of mul_add (ss = {arg1, arg2, arg3}) match a
// layout that the hand-written __half2 kernels in this file index correctly.
// If so, the caller may take the fast path; otherwise it must use the generic
// path. (The name presumably refers to the model where this pattern was
// observed — NOTE(review): confirm.)
//
// Accepted layouts, with lens taken from arg1 and shared by all three inputs:
//  * rank 2, {d0, d1}: arg1 is standard; arg2 and arg3 have identical strides
//    {0, 1}, i.e. one row broadcast over d0. mul_add_kernel_dim3 reads them at
//    index i % (d1 / 2).
//  * rank 3, {d0, d1, d2}: arg1 and arg2 are standard; arg3 has strides
//    {d2, 0, 1}, i.e. one row per d0 slice broadcast over d1. The outer stride
//    is irrelevant when d0 == 1. mul_add_kernel_dim4 reads arg3 at
//    (i / (d1 * d2/2)) * (d2/2) + i % (d2/2).
// In both cases the innermost dimension must be even so elements pair into
// __half2 values.
//
// Every stride is verified rather than just one. A weaker check would, for
// example, accept an arg3 with strides {0, 0, 1} and d0 > 1, and the kernel
// would then read past the end of arg3's buffer.
static bool is_bert(const std::vector<shape>& ss)
{
    if(ss.size() < 3)
        return false;

    const auto& s1   = ss.at(0);
    const auto& s2   = ss.at(1);
    const auto& s3   = ss.at(2);
    const auto& lens = s1.lens();

    // Guard .back() on a scalar shape, and require pairs for __half2.
    if(lens.empty() or lens.back() % 2 != 0)
        return false;

    // The kernels assume the three inputs share logical dimensions, and they
    // always read arg1 at the linear element index.
    if(s2.lens() != lens or s3.lens() != lens or not s1.standard())
        return false;

    if(lens.size() == 2)
    {
        const auto& st2 = s2.strides();
        const auto& st3 = s3.strides();
        return st2 == st3 and st2.size() == 2 and st2[0] == 0 and st2[1] == 1;
    }

    if(lens.size() == 3)
    {
        // mul_add_kernel_dim4 reads arg2 at the linear index as well.
        if(not s2.standard())
            return false;
        const auto& st3 = s3.strides();
        return st3.size() == 3 and (st3[0] == lens[2] or lens[0] == 1) and st3[1] == 0 and
               st3[2] == 1;
    }

    return false;
}
Shucai Xiao's avatar
Shucai Xiao committed
65

Paul's avatar
Paul committed
66
// Computes result = arg1 * arg2 + arg3 elementwise, enqueued on `stream`.
//
// When the result is half precision, its shape is standard, and is_bert()
// accepts the input layouts, a hand-written __half2 kernel is launched:
//  * rank 2 -> mul_add_kernel_dim3
//  * rank 3 -> mul_add_kernel_dim4
// Every other case goes through the generic nary() path.
// Kernel launch errors are not checked here; this matches the original code.
void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    const auto sr = result.get_shape();
    const std::vector<shape> ss{arg1.get_shape(), arg2.get_shape(), arg3.get_shape()};

    // The __half2 kernels write the output at the linear index, so the result
    // must also be standard for the fast path to be correct.
    if(sr.type() == shape::half_type and sr.standard() and is_bert(ss))
    {
        const auto& lens = sr.lens();
        // Number of __half2 pairs (is_bert checked the innermost dim is even).
        auto elem_num = sr.elements() / 2;
        // Nothing to compute; a launch with zero blocks is an invalid
        // configuration, so skip it.
        if(elem_num == 0)
            return;

        const int last_dim   = lens.back() / 2; // innermost dim in __half2 units
        const int block_size = 1024;
        const int block_num  = (elem_num + block_size - 1) / block_size;
        const int n          = static_cast<int>(elem_num);

        if(lens.size() == 2)
        {
            mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), n);
        }
        else
        {
            // is_bert only accepts rank 2 or 3, so this branch is rank 3.
            const int factor = lens[1];
            mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), n);
        }
    }
    else
    {
        nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
                                                   __device__ { return a * x + b; });
    }
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx