matrix_view.cuh 5.05 KB
Newer Older
CHU Tianxiang's avatar
CHU Tianxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

// Read-only view of a row-major matrix of `half` values: `height` rows by `width`
// columns, with element (row, column) stored at data[row * width + column].
// Accessors perform no bounds checking.
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Element (row, column).
    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }

    // The half2 containing element (row, column). The flat index is halved with
    // integer division, so an odd flat index yields the pair that begins one
    // element earlier. `data` must be suitably aligned for half2 (4-byte) loads.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        // Read through a const pointer: this view never writes to `data`.
        return reinterpret_cast<const half2*>(data)[(row * width + column) / 2];
    }

    // Element (row, column) duplicated into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }

    // Address of element (row, column).
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

    // Reads elements (row, column) .. (row, column + 3) into `items`.
    // Uses two half2 loads, so &data[row * width + column] must be 4-byte aligned.
    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        half2 i01, i23;
        load4(i01, i23, row, column);
        items[0] = __low2half(i01);
        items[1] = __high2half(i01);
        items[2] = __low2half(i23);
        items[3] = __high2half(i23);
    }

    // Same as item4, but each element is converted to float.
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        half2 i01, i23;
        load4(i01, i23, row, column);
        items[0] = __half2float(__low2half(i01));
        items[1] = __half2float(__high2half(i01));
        items[2] = __half2float(__low2half(i23));
        items[3] = __half2float(__high2half(i23));
    }

    // Same as item4, but each element is broadcast into both lanes of a half2.
    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        half2 i01, i23;
        load4(i01, i23, row, column);
        items[0] = __half2half2(__low2half(i01));
        items[1] = __half2half2(__high2half(i01));
        items[2] = __half2half2(__low2half(i23));
        items[3] = __half2half2(__high2half(i23));
    }

private:
    // Loads the four consecutive halves starting at (row, column) as two half2
    // values (elements 0-1 into i01, elements 2-3 into i23) via a const pointer.
    __device__ __forceinline__ void load4(half2& i01, half2& i23, int row, int column) const
    {
        const half2* ptr = reinterpret_cast<const half2*>(item_ptr(row, column));
        i01 = ptr[0];
        i23 = ptr[1];
    }
};

// Mutable view of a row-major matrix of `half` values (`height` rows by `width`
// columns). Element (row, column) is stored at data[row * width + column].
// No bounds checking is performed by any accessor.
class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Element (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        return data[flat(row, column)];
    }

    // The half2 at pair index flat(row, column) / 2 (integer division).
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        half2* pairs = (half2*) data;
        return pairs[flat(row, column) / 2];
    }

    // Writable address of element (row, column).
    __device__ __forceinline__ half* item_ptr(int row, int column)
    {
        return data + flat(row, column);
    }

    // Overwrites element (row, column) with `value`.
    __device__ __forceinline__ void set(int row, int column, half value)
    {
        data[flat(row, column)] = value;
    }

    // Overwrites the half2 at pair index flat(row, column) / 2 with `value`.
    __device__ __forceinline__ void set_half2(int row, int column, half2 value)
    {
        half2* pairs = (half2*) data;
        pairs[flat(row, column) / 2] = value;
    }

    // Writes v0..v3 to (row, column) .. (row, column + 3) with two half2 stores;
    // the target address must therefore be 4-byte aligned.
    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2* dst = (half2*) item_ptr(row, column);
        dst[0] = __halves2half2(v0, v1);
        dst[1] = __halves2half2(v2, v3);
    }

private:
    // Flat row-major offset of (row, column) into `data`.
    __device__ __forceinline__ int flat(int row, int column) const
    {
        return row * width + column;
    }
};

// Read-only view of 4-bit values packed eight per uint32_t along each row.
// Element (row, column) is the nibble at bit offset (column % 8) * 4 of word
// data[row * width / 8 + column / 8], with nibble 0 in the least-significant bits.
// NOTE(review): the shift depends only on `column`, so this presumably assumes
// `width` is a multiple of 8 — TODO confirm with callers.
class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // 4-bit value at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        return shifted_word(row, column) & 0x0f;
    }

    // Values at (row, column) and (row, column + 1), read from the same word.
    // Nibbles past the top of the word come back as 0.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const uint32_t bits = shifted_word(row, column);
        #pragma unroll
        for (int k = 0; k < 2; ++k)
            items[k] = (bits >> (4 * k)) & 0x0f;
    }

    // Values at (row, column) .. (row, column + 3), read from the same word.
    // Nibbles past the top of the word come back as 0.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const uint32_t bits = shifted_word(row, column);
        #pragma unroll
        for (int k = 0; k < 4; ++k)
            items[k] = (bits >> (4 * k)) & 0x0f;
    }

private:
    // Packed word holding (row, column), shifted right so that element
    // (row, column) occupies the lowest nibble.
    __device__ __forceinline__ uint32_t shifted_word(int row, int column) const
    {
        const int word_index = row * width / 8 + column / 8;
        const int bit_offset = (column & 0x07) * 4;
        return data[word_index] >> bit_offset;
    }
};

// Read-only view of 4-bit values packed eight per uint32_t down each column.
// Element (row, column) is the nibble at bit offset (row % 8) * 4 of word
// data[(row / 8) * width + column], with nibble 0 in the least-significant bits.
// Accessors perform no bounds checking.
class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // 4-bit value at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // Whole packed word containing (row, column), i.e. rows (row / 8) * 8 ..
    // (row / 8) * 8 + 7 of this column. These accessors only read, so they are
    // const-qualified and usable on a const view.
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const { return data[row / 8 * width + column]; }

    // Address of the packed word containing (row, column).
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const { return &data[row / 8 * width + column]; }
};

}  // namespace gptq
}  // namespace vllm
#endif