mul_add.cpp 3.05 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
4
#include <migraphx/gpu/device/mul_add.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/gpu/device/nary.hpp>
6
7
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
Paul's avatar
Paul committed
8
9
10
11
12
13

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

14
// Packed-half fused multiply-add with x and b indexed modulo `dim3`:
//
//     r[i] = a[i] * x[i % dim3] + b[i % dim3]      for 0 <= i < n
//
// All four buffers are viewed as arrays of __half2, so every index (and `n`,
// and `dim3`) counts __half2 elements, i.e. pairs of halves. __hfma2 applies
// the multiply-add to both halves of each pair.
//
// Launch layout: only the .x components of blockDim/blockIdx/threadIdx are
// used, one thread per __half2 result; threads whose global index is >= n
// return without touching memory.
//
// NOTE(review): `dim3` must be non-zero (it is a modulo divisor) and the
// buffers must be aligned for __half2 access - confirm the caller guarantees
// both. The parameter name also hides the HIP `dim3` type inside this kernel.
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
{
    const int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if(tid >= n)
        return;

    const auto* lhs   = static_cast<const __half2*>(a);
    const auto* scale = static_cast<const __half2*>(x);
    const auto* bias  = static_cast<const __half2*>(b);
    auto* out         = static_cast<__half2*>(r);

    // x and b share the same wrapped index.
    const int wrapped = tid % dim3;
    out[tid]          = __hfma2(lhs[tid], scale[wrapped], bias[wrapped]);
}
Shucai Xiao's avatar
Shucai Xiao committed
27

28
29
30
// Packed-half fused multiply-add where only b is re-indexed:
//
//     r[i] = a[i] * x[i] + b[(i / (factor * dim4)) * dim4 + i % dim4]
//
// for 0 <= i < n. All buffers are viewed as arrays of __half2, so `n`, `dim4`
// and every index count __half2 elements (pairs of halves). The b index keeps
// the outermost coordinate (i / (factor * dim4)) and the innermost coordinate
// (i % dim4) of i and discards the one in between.
//
// Launch layout: only the .x components of blockDim/blockIdx/threadIdx are
// used, one thread per __half2 result; threads whose global index is >= n
// return without touching memory.
//
// NOTE(review): `factor * dim4` and `dim4` are divisors and must be non-zero;
// buffers must be aligned for __half2 access - confirm with the caller.
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
    const int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if(tid >= n)
        return;

    const auto* lhs   = static_cast<const __half2*>(a);
    const auto* scale = static_cast<const __half2*>(x);
    const auto* bias  = static_cast<const __half2*>(b);
    auto* out         = static_cast<__half2*>(r);

    const int outer    = tid / (factor * dim4);
    const int inner    = tid % dim4;
    const int bias_idx = outer * dim4 + inner;

    out[tid] = __hfma2(lhs[tid], scale[tid], bias[bias_idx]);
}

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Decides, from the strides of the three input shapes, whether the inputs
// match one of the two layouts that the packed-half kernels in this file
// handle. `ss` is expected to hold three shapes; the rank is taken from the
// first one.
//
//  - rank 3: the third shape must have a zero stride on axis 1.
//  - rank 2: the second and third shapes must have identical strides, with a
//            zero stride on axis 0.
//  - any other rank: not supported, returns false.
//
// Only strides are inspected here; lengths and element type are left to the
// caller.
static bool is_bert(const std::vector<shape>& ss)
{
    const auto rank = ss.front().lens().size();
    switch(rank)
    {
    case 3: {
        const auto& bias_strides = ss.at(2).strides();
        return bias_strides[1] == 0;
    }
    case 2: {
        const auto& scale_strides = ss.at(1).strides();
        const auto& bias_strides  = ss.at(2).strides();
        return (scale_strides == bias_strides and scale_strides[0] == 0);
    }
    default: return false;
    }
}
Shucai Xiao's avatar
Shucai Xiao committed
59

Paul's avatar
Paul committed
60
// Launches an element-wise multiply-add on `stream`, writing into `result`.
//
// Fast path: when the result is half precision, is_bert() accepts the input
// strides, and the innermost dimension has an even length, the data is
// processed two halves at a time as __half2 by one of the dedicated kernels:
//   - rank 2: result[i] = arg1[i] * arg2[i % d] + arg3[i % d]
//   - otherwise (rank 3 per is_bert):
//             result[i] = arg1[i] * arg2[i] + arg3[(i / (f * d)) * d + i % d]
// where i indexes __half2 elements, d = lens.back() / 2 and f = lens[1].
//
// Every other case goes through the generic nary() launcher, which evaluates
// `a * x + b`.
//
// The even-length requirement matters: the kernels pair adjacent halves, so an
// odd innermost length would make `lens.back() / 2` and `elements() / 2`
// truncate, producing wrong indices and dropping the last element; a length of
// 1 would even make the kernel's modulo divisor zero.
void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    auto sr   = result.get_shape();
    auto type = sr.type();

    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());
    ss.push_back(arg3.get_shape());

    auto lens  = sr.lens();
    auto n_dim = lens.size();

    // Halves can only be paired into __half2 when the innermost length is even.
    const bool even_last_dim = not lens.empty() and lens.back() % 2 == 0;
    if(type == shape::half_type and even_last_dim and is_bert(ss))
    {
        // Number of __half2 elements to produce.
        auto elem_num = sr.elements() / 2;
        if(elem_num == 0)
        {
            // Nothing to compute; a grid with zero blocks is not a valid launch.
            return;
        }

        int last_dim   = lens.back() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;
        if(n_dim == 2)
        {
            mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
        }
        else
        {
            int factor = lens[1];
            mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
                arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
        }
    }
    else
    {
        nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
                                                   __device__ { return a * x + b; });
    }
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx