// Copyright (c) OpenMMLab. All rights reserved.

#include "common.h"
#include <iostream>

namespace turbomind {

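// Atomically replaces the 4-bit value at nibble `index` of *address
// (bits [4*index, 4*index + 4)) with `value`, using a CAS loop so that
// concurrent writers to other nibbles of the same word are not lost.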
__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
    uint32_t old = *address;
    uint32_t assumed;
    do {
        assumed      = old;
        uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
        old          = atomicCAS(address, assumed, tmp);
    } while (assumed != old);
}

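// Returns the 4-bit value at nibble `index` of the word at `address`.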
__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
    return (*address >> (index * 4u)) & 0xfu;
}
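
// Inverse of read_u4: packs eight 4-bit values into one 32-bit word, value i
// occupying bits [4*i, 4*i + 4). Illustrative host-side sketch for tests
// (the name is ours; not used by the kernels).
static inline uint32_t pack_u4_reference(const uint32_t (&v)[8])
{
    uint32_t w = 0;
    for (uint32_t i = 0; i < 8; ++i) {
        w |= (v[i] & 0xfu) << (i * 4u);
    }
    return w;
}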

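// Generic permutation of a packed 4-bit tensor: `dims` is the row-major shape
// of the source, and destination dimension j is source dimension Ds[j].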
template<int... Ds>
__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
{
    constexpr int N = sizeof...(Ds);

    size_t count = 1;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        count *= dims[i];
    }

    constexpr int order[] = {Ds...};

    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {

        int indices[N]{};

        PRAGMA_UNROLL
        for (int j = N - 1, ii = i; j >= 0; --j) {
            indices[j] = ii % dims[j];
            ii /= dims[j];
        }

        auto data = read_u4(src + i / 8, i % 8);

        int index = 0;

        PRAGMA_UNROLL
        for (int j = N - 1, stride = 1; j >= 0; --j) {
            index += indices[order[j]] * stride;
            stride *= dims[order[j]];
        }

        atomic_assign_u4(dst + index / 8, index % 8, data);
    }
}
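
// Host-side reference for the index mapping in permute_u4 above (illustrative
// sketch for testing; the name is ours): decompose the flat source index into
// row-major coordinates over `dims`, then recompose taking the dimensions in
// `order`.
static inline size_t permuted_index_reference(size_t i, const int* dims, const int* order, int n)
{
    int idx[16]{};  // row-major coordinates; assumes n <= 16
    for (int j = n - 1; j >= 0; --j) {
        idx[j] = static_cast<int>(i % dims[j]);
        i /= dims[j];
    }
    size_t index = 0, stride = 1;
    for (int j = n - 1; j >= 0; --j) {
        index += static_cast<size_t>(idx[order[j]]) * stride;
        stride *= dims[order[j]];
    }
    return index;
}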

void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k/8, m] layout
    Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
    //        |warp|  lane  | 2x2 |  a0-7  |
    permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
}

void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k, m/8] layout
    Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
    //        |warp|  lane  | 2x2 |  a0-7  |
    permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
}

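// Expands each uint32_t of eight packed 4-bit values into eight fp16 values
// (one uint4 per source word).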
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
    }
}

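// Packs per-group (zero, scale) pairs into half2. With the HFMA2 dequant path
// enabled, the zero point is pre-folded as -zero * scale (see the sketch
// after this function).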
__global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        if (TURBOMIND_S4_DEQUANT_USE_FMA) {
            // dequant via HFMA2 has a numerical stability issue
            Q[i] = __halves2half2(-zeros[i] * scales[i], scales[i]);
        }
        else {
            Q[i] = __halves2half2(zeros[i], scales[i]);
        }
    }
}
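
// Illustrative host-side check of the identity the HFMA2 path relies on
// (the name is ours): (w - z) * s == w * s + (-z * s), which is why merge_Q
// can pre-fold the zero point into the first half of the packed half2.
static inline float dequant_fma_reference(float w, float z, float s)
{
    const float pre_zero = -z * s;  // matches the first half stored by merge_Q
    return w * s + pre_zero;        // equals (w - z) * s
}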

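// Prepares 4-bit weights for the s4 GEMM path: dequantize the packed zero
// points into `workspace`, fuse (zero, scale) pairs into `Q_dst`, and reorder
// the quantized weights into the [k, m/8] layout.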
void convert_s4_k_m8(uint32_t*       A_dst,
                     half2*          Q_dst,
                     half*           workspace,
                     const uint32_t* A_src,
                     const half*     scales,
                     const uint32_t* qzeros,
                     int             m,
                     int             k,
                     int             group_size,
                     cudaStream_t    st)
{
    dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);

    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);

    reformat_s4_k_m8(A_dst, A_src, m, k, st);
}

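// The arrow comments below track how the dimensions of `shape` are reordered
// through the dequant -> transpose -> quant steps; the final sequence is the
// template argument list passed to permute_u4.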
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
    Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
    //      dequant   transpose    quant
    // 0123456 -> 0123564 -> 0135642 -> 0135264
    permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
}

// [2, k, m/8] -> [k, m/8, 2]
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
    //     dequant   transpose   quant
    // 012345 -> 012453 -> 124530 -> 124053
    permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
}

__global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2(src[i]);
    }
}

void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
{
    dequantize_s4_kernel<<<512, 512, 0, st>>>(dst, src, count);
}

}  // namespace turbomind