ggml-aarch64.c 4.37 KB
Newer Older
xuxzh1's avatar
update  
xuxzh1 committed
1
#define GGML_COMMON_DECL_C
xuxzh1's avatar
init  
xuxzh1 committed
2
3
#include "ggml-common.h"

xuxzh1's avatar
update  
xuxzh1 committed
4
#include "ggml-aarch64.h"
xuxzh1's avatar
init  
xuxzh1 committed
5
#include "ggml-impl.h"
xuxzh1's avatar
update  
xuxzh1 committed
6
#include "ggml-quants.h"
xuxzh1's avatar
init  
xuxzh1 committed
7
8
9
10
#include <assert.h>

#define UNUSED GGML_UNUSED

xuxzh1's avatar
update  
xuxzh1 committed
11
// Interleave 4 block_q4_0s in chunks of blck_size_interleave and return the
// resulting block_q4_0x4.
//
// Layout of the result: the 4 deltas come first (in[0].d .. in[3].d), then the
// quants, taken round-robin from the 4 inputs one chunk at a time
// (chunk 0 of in[0], chunk 0 of in[1], ..., chunk 1 of in[0], ...).
// Every chunk is XORed with 0x88 per byte on the way through, i.e. the top bit
// of each 4-bit nibble is flipped.
// NOTE(review): presumably this turns the offset-by-8 nibbles into signed
// nibbles for the kernels that consume this layout - confirm against them.
//
// in                   - at least 4 consecutive block_q4_0 to read from
// blck_size_interleave - chunk size; only 4 and 8 are handled, anything else
//                        hits GGML_ASSERT(false)
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x4 out;

    for (int blk = 0; blk < 4; blk++) {
        out.d[blk] = in[blk].d;
    }

    // total number of chunks to move across all 4 source blocks
    const int n_chunks = QK4_0 * 2 / blck_size_interleave;

    switch (blck_size_interleave) {
        case 8: {
            for (int chunk = 0; chunk < n_chunks; ++chunk) {
                const int blk  = chunk % 4;        // source block, round-robin
                const int from = (chunk / 4) * 8;  // offset inside that block's quants
                const int to   = chunk * 8;        // chunks are laid out back to back

                // go through memcpy so that no unaligned 64-bit access is made
                uint64_t bits;
                memcpy(&bits, &in[blk].qs[from], sizeof(bits));
                bits ^= 0x8888888888888888ULL;
                memcpy(&out.qs[to], &bits, sizeof(bits));
            }
            break;
        }
        case 4: {
            for (int chunk = 0; chunk < n_chunks; ++chunk) {
                const int blk  = chunk % 4;
                const int from = (chunk / 4) * 4;
                const int to   = chunk * 4;

                uint32_t bits;
                memcpy(&bits, &in[blk].qs[from], sizeof(bits));
                bits ^= 0x88888888;
                memcpy(&out.qs[to], &bits, sizeof(bits));
            }
            break;
        }
        default:
            GGML_ASSERT(false);
            break;
    }

    return out;
}

// interleave 8 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x8
// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
xuxzh1's avatar
update  
xuxzh1 committed
56
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
xuxzh1's avatar
init  
xuxzh1 committed
57
58
59
60
61
62
    block_q4_0x8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].d;
    }

xuxzh1's avatar
update  
xuxzh1 committed
63
64
    const int end = QK4_0 * 4 / blck_size_interleave;
    const uint64_t xor_mask = 0x8888888888888888ULL;
xuxzh1's avatar
init  
xuxzh1 committed
65

xuxzh1's avatar
update  
xuxzh1 committed
66
67
68
69
    for (int i = 0; i < end; ++i) {
        int src_id = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;
xuxzh1's avatar
init  
xuxzh1 committed
70

xuxzh1's avatar
update  
xuxzh1 committed
71
72
73
74
        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        elems ^= xor_mask;
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
xuxzh1's avatar
init  
xuxzh1 committed
75
76
    }

xuxzh1's avatar
update  
xuxzh1 committed
77
    return out;
xuxzh1's avatar
init  
xuxzh1 committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
}

// Quantize nrow rows of n_per_row floats to Q4_0 and store them with
// nrows_interleaved rows interleaved per output block.
//
// Rows are processed in groups of nrows_interleaved. For every group and every
// QK4_0-wide column block, that column block of each row in the group is
// quantized with quantize_row_q4_0_ref into a temporary block_q4_0, and the
// temporaries are then packed by make_block_q4_0x4 / make_block_q4_0x8 into one
// interleaved block that is appended to dst.
//
// src                  - nrow * n_per_row input floats, row-major
// dst                  - output buffer, written as consecutive block_q4_0x4
//                        (nrows_interleaved == 4) or block_q4_0x8 (== 8)
// nrow                 - number of rows; must be a multiple of nrows_interleaved
// n_per_row            - floats per row; must be a multiple of QK4_0
// nrows_interleaved    - rows packed together per output block; 4 or 8
// blck_size_interleave - chunk size forwarded to make_block_q4_0x4/x8
//
// Returns (nrow * n_per_row) / QK4_0 * sizeof(block_q4_0).
static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
    assert(n_per_row % QK4_0 == 0);
    // Only these two layouts have a packer below; any other value would run the
    // loops without ever writing to dst (or overflow dst_tmp when > 8).
    assert(nrows_interleaved == 4 || nrows_interleaved == 8);
    // A partial last group would read rows past the end of src.
    assert(nrow % nrows_interleaved == 0);

    // Element counts are kept in 64 bits: nrow * n_per_row can exceed INT_MAX
    // for large tensors, and a narrower loop counter would be truncated.
    const int64_t nb           = n_per_row / QK4_0;
    const int64_t n_elements   = nrow * n_per_row;
    const int64_t group_stride = (int64_t) nrows_interleaved * n_per_row;

    void * out_ptr = dst;
    block_q4_0 dst_tmp[8];

    for (int64_t b = 0; b < n_elements; b += group_stride) {
        for (int64_t x = 0; x < nb; x++) {
            // quantize column block x of every row in the current group
            for (int i = 0; i < nrows_interleaved; i++) {
                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, dst_tmp + i, QK4_0);
            }

            // pack the group into one interleaved block and advance the output
            if (nrows_interleaved == 8) {
                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
                out_ptr = (block_q4_0x8 *) out_ptr + 1;
            } else if (nrows_interleaved == 4) {
                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
                out_ptr = (block_q4_0x4 *) out_ptr + 1;
            }
        }
    }

    return (n_elements / QK4_0 * sizeof(block_q4_0));
}

// Quantize to Q4_0 with 4 rows interleaved in chunks of size 4.
// Thin wrapper around quantize_q4_0_nr_bl; quant_weights is accepted but ignored.
// Returns whatever quantize_q4_0_nr_bl returns for these arguments.
size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    const int nrows_interleaved    = 4;
    const int blck_size_interleave = 4;

    UNUSED(quant_weights);

    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, nrows_interleaved, blck_size_interleave);
}

// Quantize to Q4_0 with 4 rows interleaved in chunks of size 8.
// Thin wrapper around quantize_q4_0_nr_bl; quant_weights is accepted but ignored.
// Returns whatever quantize_q4_0_nr_bl returns for these arguments.
size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    const int nrows_interleaved    = 4;
    const int blck_size_interleave = 8;

    UNUSED(quant_weights);

    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, nrows_interleaved, blck_size_interleave);
}

// Quantize to Q4_0 with 8 rows interleaved in chunks of size 8.
// Thin wrapper around quantize_q4_0_nr_bl; quant_weights is accepted but ignored.
// Returns whatever quantize_q4_0_nr_bl returns for these arguments.
size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    const int nrows_interleaved    = 8;
    const int blck_size_interleave = 8;

    UNUSED(quant_weights);

    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, nrows_interleaved, blck_size_interleave);
}