matrix_view.cuh 9.13 KB
Newer Older
nicodafagood's avatar
nicodafagood committed
1
2
3
4
5
6
7
8
9
10
11
12
13
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "qdq_util.cuh"

namespace vllm {
nicodafagood's avatar
nicodafagood committed
14
namespace mygq {
nicodafagood's avatar
nicodafagood committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271

// Read-only view that addresses a flat `half` buffer as a row-major matrix:
// element (row, column) lives at data[row * width + column].
// None of the accessors performs bounds checking, and `height` is stored
// without being consulted by any method of this class.
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Scalar element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        const int flat = row * width + column;
        return data[flat];
    }

    // The half2 pair holding (row, column): the buffer is reinterpreted as an
    // array of half2 and indexed with flat_index / 2.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        const int pair_index = (row * width + column) / 2;
        return reinterpret_cast<const half2*>(data)[pair_index];
    }

    // Element at (row, column) duplicated into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const
    {
        const int flat = row * width + column;
        return __half2half2(data[flat]);
    }

    // Address of the element at (row, column).
    __device__ __forceinline__ const half* item_ptr(int row, int column) const
    {
        return data + (row * width + column);
    }

    // Loads four consecutive elements starting at (row, column) as two half2
    // reads and unpacks them into `items`.
    // NOTE(review): the half2 reads assume the element address is suitably
    // aligned for half2 - confirm against callers.
    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        const half2* pairs = reinterpret_cast<const half2*>(item_ptr(row, column));
        const half2 first = pairs[0];
        const half2 second = pairs[1];

        items[0] = __low2half(first);
        items[1] = __high2half(first);
        items[2] = __low2half(second);
        items[3] = __high2half(second);
    }

    // Same four-element load as item4, with each value converted to float.
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        const half2* pairs = reinterpret_cast<const half2*>(item_ptr(row, column));
        const half2 first = pairs[0];
        const half2 second = pairs[1];

        items[0] = __half2float(__low2half(first));
        items[1] = __half2float(__high2half(first));
        items[2] = __half2float(__low2half(second));
        items[3] = __half2float(__high2half(second));
    }

    // Same four-element load as item4, with each value broadcast to both
    // lanes of its own half2.
    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        const half2* pairs = reinterpret_cast<const half2*>(item_ptr(row, column));
        const half2 first = pairs[0];
        const half2 second = pairs[1];

        items[0] = __half2half2(__low2half(first));
        items[1] = __half2half2(__high2half(first));
        items[2] = __half2half2(__low2half(second));
        items[3] = __half2half2(__high2half(second));
    }
};

// Read/write view that addresses a flat `half` buffer as a row-major matrix:
// element (row, column) lives at data[row * width + column].
// None of the accessors performs bounds checking, and `height` is stored
// without being consulted by any method of this class.
class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Scalar element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        const int flat = row * width + column;
        return data[flat];
    }

    // The half2 pair holding (row, column): the buffer is reinterpreted as an
    // array of half2 and indexed with flat_index / 2.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        const int pair_index = (row * width + column) / 2;
        return reinterpret_cast<const half2*>(data)[pair_index];
    }

    // Writable address of the element at (row, column).
    __device__ __forceinline__ half* item_ptr(int row, int column)
    {
        return data + (row * width + column);
    }

    // Stores a single element at (row, column).
    __device__ __forceinline__ void set(int row, int column, half value)
    {
        const int flat = row * width + column;
        data[flat] = value;
    }

    // Stores a half2 into the pair slot flat_index / 2 of the buffer viewed
    // as an array of half2.
    __device__ __forceinline__ void set_half2(int row, int column, half2 value)
    {
        const int pair_index = (row * width + column) / 2;
        reinterpret_cast<half2*>(data)[pair_index] = value;
    }

    // Writes v0..v3 to four consecutive elements starting at (row, column),
    // issued as two half2 stores.
    // NOTE(review): the half2 stores assume the element address is suitably
    // aligned for half2 - confirm against callers.
    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2* dst = reinterpret_cast<half2*>(item_ptr(row, column));
        dst[0] = __halves2half2(v0, v1);
        dst[1] = __halves2half2(v2, v3);
    }
};

// Read-only view over 4-bit values packed along rows: each 32-bit word holds
// eight consecutive columns of one row, lowest nibble first. The word for
// (row, column) is data[row * width / 8 + column / 8].
// No bounds checking is performed; `height` is stored but not consulted.
class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 4-bit value at (row, column), in the range 0..15.
    __device__ __forceinline__ int item(int row, int column) const
    {
        const uint32_t word = data[row * width / 8 + column / 8];
        const int bit_offset = (column & 0x07) << 2;
        return (word >> bit_offset) & 0x0f;
    }

    // Two consecutive values starting at (row, column). Both are taken from
    // the single word containing `column`; nothing is read from the next word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const int bit_offset = (column & 0x07) << 2;
        const uint32_t packed = data[row * width / 8 + column / 8] >> bit_offset;

        items[0] = packed & 0x0f;
        items[1] = (packed >> 4) & 0x0f;
    }

    // Four consecutive values starting at (row, column). All are taken from
    // the single word containing `column`; nothing is read from the next word.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int bit_offset = (column & 0x07) << 2;
        const uint32_t packed = data[row * width / 8 + column / 8] >> bit_offset;

        items[0] = packed & 0x0f;
        items[1] = (packed >> 4) & 0x0f;
        items[2] = (packed >> 8) & 0x0f;
        items[3] = (packed >> 12) & 0x0f;
    }
};

// Read-only view over 4-bit values packed along columns: each 32-bit word
// holds eight consecutive rows of one column, lowest nibble first. The word
// for (row, column) is data[row / 8 * width + column].
// No bounds checking is performed; `height` is stored but not consulted.
class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 4-bit value at (row, column), in the range 0..15.
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // The whole packed word that contains (row, column), i.e. the eight rows
    // of the group row / 8 for this column.
    // Declared const (it only reads) so it can be called on a const view,
    // consistent with item().
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const
    {
        return data[row / 8 * width + column];
    }

    // Address of the packed word that contains (row, column).
    // Declared const for the same reason as item_uint32_t().
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const
    {
        return &data[row / 8 * width + column];
    }
};

// Read-only view over 2-bit values packed along rows: each 32-bit word holds
// sixteen consecutive columns of one row, lowest bit pair first. The word
// for (row, column) is data[row * width / 16 + column / 16].
// No bounds checking is performed; `height` is stored but not consulted.
class MatrixView_q2_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 2-bit value at (row, column), in the range 0..3.
    __device__ __forceinline__ int item(int row, int column) const
    {
        const uint32_t word = data[row * width / 16 + column / 16];
        const int bit_offset = (column & 0x0f) << 1;
        return (word >> bit_offset) & 0x03;
    }

    // Two consecutive values starting at (row, column). Both are taken from
    // the single word containing `column`; nothing is read from the next word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) << 1;
        const uint32_t packed = data[row * width / 16 + column / 16] >> bit_offset;

        items[0] = packed & 0x03;
        items[1] = (packed >> 2) & 0x03;
    }

    // Four consecutive values starting at (row, column). All are taken from
    // the single word containing `column`; nothing is read from the next word.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) << 1;
        const uint32_t packed = data[row * width / 16 + column / 16] >> bit_offset;

        items[0] = packed & 0x03;
        items[1] = (packed >> 2) & 0x03;
        items[2] = (packed >> 4) & 0x03;
        items[3] = (packed >> 6) & 0x03;
    }
};

// Read-only view over 3-bit values packed along rows: 32 consecutive columns
// of a row occupy 96 bits, i.e. three 32-bit words, lowest bits first.
// Within such a group the values at positions 10 and 21 straddle a word
// boundary (bits 30..32 and 63..65) and are assembled from two words.
// No bounds checking is performed; `height` is stored but not consulted.
class MatrixView_q3_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 3-bit value at (row, column), in the range 0..7.
    __device__ __forceinline__ int item(int row, int column) const
    {
        const int pos = column & 0x1f;                               // position inside the 32-column group
        const int word = row * width * 3 / 32 + column * 3 / 32;     // word holding the value's lowest bit

        switch (pos) {
        case 10:
            // Two low bits sit at the top of this word, the high bit at the
            // bottom of the next one.
            return (data[word] >> 30) | ((data[word + 1] << 2) & 0x4);
        case 21:
            // One low bit sits at the top of this word, the two high bits at
            // the bottom of the next one.
            return (data[word] >> 31) | ((data[word + 1] << 1) & 0x6);
        default:
            break;
        }

        // Value lies entirely inside one word; subtract the bits consumed by
        // the preceding word(s) of the group to get the in-word offset.
        int bit_offset = pos * 3;
        if (pos >= 21) {
            bit_offset -= 64;
        } else if (pos >= 10) {
            bit_offset -= 32;
        }
        return (data[word] >> bit_offset) & 0x07;
    }

    // Four consecutive values starting at (row, column).
    // Starting positions 8 and 20 within a group straddle a word boundary
    // and are stitched from two words; every other start reads one word.
    // NOTE(review): presumably intended for starting columns that are
    // multiples of 4; other starts are not handled in general (positions
    // 5..7, for example, produce a negative shift count) - confirm callers.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int pos = column & 0x1f;
        const int word = row * width / 32 * 3 + column * 3 / 32;

        uint32_t bits;
        switch (pos) {
        case 8:
            // 8 bits from the top of this word + 4 bits from the next.
            bits = (data[word] >> 24) | ((data[word + 1] & 0x0f) << 8);
            break;
        case 20:
            // 4 bits from the top of this word + 8 bits from the next.
            bits = (data[word] >> 28) | ((data[word + 1] & 0xff) << 4);
            break;
        default:
            {
                int bit_offset = pos * 3;
                if (pos > 16) {
                    bit_offset -= 64;
                } else if (pos > 4) {
                    bit_offset -= 32;
                }
                bits = data[word] >> bit_offset;
            }
            break;
        }

        items[0] = bits & 0x07;
        items[1] = (bits >> 3) & 0x07;
        items[2] = (bits >> 6) & 0x07;
        items[3] = (bits >> 9) & 0x07;
    }
};

// Read-only view over 8-bit values packed along rows: each 32-bit word holds
// four consecutive columns of one row, lowest byte first. The word for
// (row, column) is data[row * width / 4 + column / 4].
// No bounds checking is performed; `height` is stored but not consulted.
class MatrixView_q8_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // The 8-bit value at (row, column), in the range 0..255.
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
    }

    // Two consecutive values starting at (row, column). Both are taken from
    // the single word containing `column`; nothing is read from the next word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
    }

    // Four consecutive values starting at (row, column). All are taken from
    // the single word containing `column`; nothing is read from the next
    // word, so slots past the end of the word come out as 0 when `column`
    // is not a multiple of 4.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        // Each value is 8 bits wide, so the in-word offset is 8 bits per
        // column (the same as item/item2). For columns that are multiples of
        // 4 the offset is 0.
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
        items[2] = (d >> 16) & 0xff;
        items[3] = (d >> 24) & 0xff;
    }
};

nicodafagood's avatar
nicodafagood committed
272
}  // namespace mygq
nicodafagood's avatar
nicodafagood committed
273
274
}  // namespace vllm
#endif