mul_add.cpp 3.99 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
4
#include <migraphx/gpu/device/mul_add.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/gpu/device/nary.hpp>
6
7
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
Paul's avatar
Paul committed
8
9
10
11
12
13

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

14
15
16
17
18
19
20
21
22
23
24
25
26
//__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int n)
//{
//    int id = blockDim.x * blockIdx.x + threadIdx.x;
//    __half* ha = reinterpret_cast<__half*>(a);
//    __half* hb = reinterpret_cast<__half*>(b);
//    __half* hx = reinterpret_cast<__half*>(x);
//    __half* hr = reinterpret_cast<__half*>(r);
//    if (id < n)
//    {
//        hr[id] = __float2half(__half2float(ha[id]) * __half2float(hx[id]) + __half2float(hb[id]));
//    }
//}

Shucai Xiao's avatar
Shucai Xiao committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// __global__ void mul_add_kernel(void* a, int an, void* x, int xn, void* b, int bn, void* r, int n)
// {
//     int id      = blockDim.x * blockIdx.x + threadIdx.x;
//     __half2* ha = reinterpret_cast<__half2*>(a);
//     __half2* hb = reinterpret_cast<__half2*>(b);
//     __half2* hx = reinterpret_cast<__half2*>(x);
//     __half2* hr = reinterpret_cast<__half2*>(r);
//     if(id < n)
//     {
//         hr[id] = __hadd2(__hmul2(ha[id % an], hx[id % xn]), hb[id % bn]);
//     }
// }

// Element-wise multiply-add on packed half2 data: r = a * x + b.
//
// Each thread produces one __half2 (two fp16 lanes).
// - Output addressing: r is indexed by the thread's flat id over the whole
//   launch. From fastest to slowest varying, the order is threadIdx.z,
//   threadIdx.y, threadIdx.x, blockIdx.z, blockIdx.y, blockIdx.x.
// - Input addressing: a, x and b are indexed through caller-supplied strides.
//   Presumably this lets them be broadcast against r (TODO confirm with caller).
//
// strides: 18 ints, six per input in the order a, x, b. For input k
//   (base = 6 * k) the element offset is
//       strides[base+1]*blockIdx.x  + strides[base+2]*blockIdx.y
//     + strides[base+3]*blockIdx.z  + strides[base+4]*threadIdx.x
//     + strides[base+5]*threadIdx.y + threadIdx.z
//   so threadIdx.z always has stride 1. strides[0], strides[6] and
//   strides[12] are not read by this kernel (their meaning: TODO confirm).
//   The pointer must be readable from device code.
// elem_num: number of __half2 outputs; threads whose flat id is >= elem_num
//   (or all threads, when elem_num <= 0) write nothing.
//
// Shared memory: 18 ints, statically sized. Every thread of a block reaches
// the __syncthreads() below before any thread exits.
// NOTE(review): mul_add in this file does not launch this kernel; it uses
// gs_launch instead.
__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides, int elem_num)
{
    constexpr int stride_count = 18;
    __shared__ int shared_strides[stride_count];

    // Linear index of this thread inside its block (threadIdx.z fastest).
    const int local_id =
        threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
    const int block_threads = blockDim.x * blockDim.y * blockDim.z;

    // Stage the strides in shared memory. The block-stride loop also fills
    // every entry when the block has fewer than stride_count threads.
    for(int k = local_id; k < stride_count; k += block_threads)
    {
        shared_strides[k] = strides[k];
    }
    __syncthreads();

    // Compute the flat id in 64 bits so oversized grids cannot wrap to a
    // negative index that would slip past the bounds check.
    const size_t block_id =
        (static_cast<size_t>(blockIdx.x) * gridDim.y + blockIdx.y) * gridDim.z + blockIdx.z;
    const size_t gid = block_id * static_cast<size_t>(block_threads) + local_id;
    if(elem_num <= 0 || gid >= static_cast<size_t>(elem_num))
    {
        return;
    }

    const __half2* ha = reinterpret_cast<const __half2*>(a);
    const __half2* hx = reinterpret_cast<const __half2*>(x);
    const __half2* hb = reinterpret_cast<const __half2*>(b);
    __half2* hr       = reinterpret_cast<__half2*>(r);

    // Offset of this thread's element in input k (0 = a, 1 = x, 2 = b).
    // The result is narrowed to int, as in the original per-input offsets.
    auto input_offset = [&](int k) -> int {
        const int* s = shared_strides + 6 * k;
        return s[1] * blockIdx.x + s[2] * blockIdx.y + s[3] * blockIdx.z + s[4] * threadIdx.x +
               s[5] * threadIdx.y + threadIdx.z;
    };

    hr[gid] = __hadd2(__hmul2(ha[input_offset(0)], hx[input_offset(1)]), hb[input_offset(2)]);
}

Shucai Xiao's avatar
Shucai Xiao committed
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// void mul_add(hipStream_t stream,
//              const argument& result,
//              const argument& arg1,
//              const argument& arg2,
//              const argument& arg3)
// {
//     auto type = result.get_shape().type();
//     if(type == shape::half_type)
//     {
//         std::cout << "case1" << std::endl;
//         mul_add_kernel<<<block_num, block_size>>>(
//             arg1.data(), s1e, arg2.data(), s2e, arg3.data(), s3e, result.data(), elem_num);
//     }
//     else
//     {
//         std::cout << "mul_add" << std::endl;
//         nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
//                                                    __device__ { return a * x + b; });
//     }
// }

Paul's avatar
Paul committed
86
// Computes result = arg1 * arg2 + arg3 element-wise, launched on `stream`.
//
// result and arg1 are indexed with the launch index `i`. arg2 and arg3 are
// indexed with `dsr.multi(i)`, i.e. the multi-index of `i` in result's shape.
// Presumably gs_launch supplies a flat index in [0, sr.elements()) and
// arg2/arg3 may be broadcast through their own strides — TODO confirm against
// gs_launch / hip_visit_all.
void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    // Output shape: sizes the launch and (as dsr) maps a flat index to a multi-index.
    auto sr = result.get_shape();

    hip_visit_all(result, arg1, arg2, arg3, sr)([&](auto r, auto i1, auto i2, auto i3, auto dsr) {
        gs_launch(stream, sr.elements())([=](auto i) __device__ {
            auto idx = dsr.multi(i);
            r[i]     = i1[i] * i2[idx] + i3[idx];
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx