amd_inline_asm.hpp 6.2 KB
Newer Older
1
2
3
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

4
#include "float_type.hpp"
Jing Zhang's avatar
Jing Zhang committed
5

6
7
namespace ck {

8
// outer-product: c[i,j] += inner_product(a[i], b[j])
9
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
10
{
11
12
13
14
15
16
17
18
19
20
21
22
23
// disable inline asm due to the compiler issue: SWDEV-202749
///\to-do: enable the inline asm after the compiler fix
#if CK_WORKAROUND_SWDEV_202749
    c0 += a * b0;
    c1 += a * b1;
#else
    asm volatile("\n \
            v_mac_f32 %0, %2, %3 \n \
            v_mac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
Chao Liu's avatar
Chao Liu committed
24
25
}

26
// outer-product: c[i,j] += inner_product(a[i], b[j])
27
__device__ void amd_assembly_outer_product_1x4(
28
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
Chao Liu's avatar
Chao Liu committed
29
{
Jing Zhang's avatar
Jing Zhang committed
30
31
32
33
34
35
    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
36
37
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
Jing Zhang's avatar
Jing Zhang committed
38
39
}

40
// outer-product: c[i,j] += inner_product(a[i], b[j])
41
42
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
43
{
44
    asm volatile("\n \
45
46
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
47
48
49
50
51
52
53
            "
                 : "=v"(c0), "=v"(c1) // Dest registers
                 : "v"(a),            // 1st Src register for 1 half2 registers
                   "v"(b0),           // 2nd Src register
                   "v"(b1),
                   "0"(c0), // 3rd Src register
                   "1"(c1));
Jing Zhang's avatar
Jing Zhang committed
54
55
}

56
// outer-product: c[i,j] += inner_product(a[i], b[j])
57
58
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
59
{
60
61
62
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
Chao Liu's avatar
Chao Liu committed
63

64
65
    // do dot2 two times
    asm volatile("\n \
66
67
68
69
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
70
71
72
73
74
75
76
77
78
79
            "
                 : "=v"(c0), "=v"(c1) // Dest registers
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
                   "0"(c0),
                   "1"(c1)); // 3rd Src Acc registers for 2 half2 registers
Jing Zhang's avatar
Jing Zhang committed
80
81
}

82
// outer-product: c[i,j] += inner_product(a[i], b[j])
83
84
85
86
87
88
89
90
91
__device__ void amd_assembly_outer_product_1x4(half2_t a,
                                               half2_t b0,
                                               half2_t b1,
                                               half2_t b2,
                                               half2_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
Jing Zhang's avatar
Jing Zhang committed
92
{
93
    asm volatile("\n \
94
95
96
97
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
98
99
100
101
102
103
104
105
106
107
108
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
                 : "v"(a),                                // 1st Src register for 1 half2 registers
                   "v"(b0),                               // 2nd Src register
                   "v"(b1),
                   "v"(b2),
                   "v"(b3),
                   "0"(c0), // 3rd Src register
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
Jing Zhang's avatar
Jing Zhang committed
109
110
}

111
// outer-product: c[i,j] += inner_product(a[i], b[j])
112
113
114
115
116
117
118
119
120
__device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               half4_t b0,
                                               half4_t b1,
                                               half4_t b2,
                                               half4_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
Jing Zhang's avatar
Jing Zhang committed
121
{
122
123
124
125
126
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
Jing Zhang's avatar
Jing Zhang committed
127

128
129
    // do dot2 two times
    asm volatile("\n \
130
131
132
133
134
135
136
137
            v_dot2_f32_f16 %0, %4, %6,  %0\n \
            v_dot2_f32_f16 %1, %4, %8,  %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7,  %0\n \
            v_dot2_f32_f16 %1, %5, %9,  %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
Jing Zhang's avatar
Jing Zhang committed
138
            "
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
                   "v"(p_b2_half2[0]),
                   "v"(p_b2_half2[1]),
                   "v"(p_b3_half2[0]),
                   "v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3)); // 3rd Src Acc registers for 2 half2 registers
Jing Zhang's avatar
Jing Zhang committed
154
}
155
156
157

} // namespace ck
#endif