"vscode:/vscode.git/clone" did not exist on "b4553de518104ed34d0bf683007042cf3a7eabf9"
Commit 18c42e67 authored by chenxl's avatar chenxl
Browse files

Initial commit

parents
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm82.cpp
// Copyright 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __aarch64__
#define iqk_mul_mat iqk_mul_mat_arm82
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82
#include "iqk_mul_mat.inc"
#endif // __aarch64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/macros.h
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#pragma once
#define MIN(X, Y) ((Y) > (X) ? (X) : (Y))
#define MAX(X, Y) ((Y) < (X) ? (X) : (Y))
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
#define ROUNDUP(X, K) (((X) + (K) - 1) & -(K))
#define ARRAYLEN(A) ((sizeof(A) / sizeof(*(A))) / ((unsigned)!(sizeof(A) % sizeof(*(A)))))
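//
// Illustrative examples (not part of the original header), assuming only the
// macros above:
//
//   CEIL_DIV(10, 4)   // -> 3; rounds the quotient up
//   ROUNDUP(10, 4)    // -> 12; K must be a power of two for the mask trick
//   int buf[7];
//   ARRAYLEN(buf)     // -> 7; the extra divisor is a compile-time sanity
//                     //    check that sizeof(A) divides evenly by the
//                     //    element size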
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/micros.h
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#pragma once
#include <ctime>
#ifndef _WIN32
#include <unistd.h>
#else
#include <windows.h>
#endif
#ifdef _WIN32
static long long GetQueryPerformanceFrequency() {
LARGE_INTEGER t;
QueryPerformanceFrequency(&t);
return t.QuadPart;
}
static long long GetQueryPerformanceCounter() {
LARGE_INTEGER t;
QueryPerformanceCounter(&t);
return t.QuadPart;
}
#endif
static long long micros(void) {
#ifndef _WIN32
struct timespec ts;
clock_gettime(CLOCK_REALTIME, &ts);
return (long long)ts.tv_sec * 1000000 + (ts.tv_nsec + 999) / 1000;
#else
static long long timer_freq = GetQueryPerformanceFrequency();
static long long timer_start = GetQueryPerformanceCounter();
return ((GetQueryPerformanceCounter() - timer_start) * 1000000) / timer_freq;
#endif
}
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/numba.h
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#pragma once
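// rand32() below is a 64-bit linear congruential generator using Knuth's MMIX
// multiplier and increment; its high 32 bits are returned as the random value,
// and float01() maps 23 of those bits into the open interval (0,1).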
inline int rand32(void) {
static unsigned long long lcg = 1;
lcg *= 6364136223846793005;
lcg += 1442695040888963407;
return lcg >> 32;
}
inline int popcount(unsigned x) {
x = x - ((x >> 1) & 0x55555555);
x = ((x >> 2) & 0x33333333) + (x & 0x33333333);
x = (x + (x >> 4)) & 0x0F0F0F0F;
x = (x + (x >> 16));
return (x + (x >> 8)) & 0x0000003F;
}
inline int hamming(int x, int y) {
return popcount(x ^ y);
}
inline float float01(unsigned x) { // (0,1)
return 1.f / 8388608 * ((x >> 9) + .5f);
}
inline float numba(void) { // (-1,1)
return float01(rand32()) * 2.f - 1.f;
}
template <typename T>
void randomize(T* A, int n) {
for (int i = 0; i < n; ++i)
A[i] = numba();
}
template <typename T>
void randomize(int m, int n, T* A, int lda) {
for (int j = 0; j < n; ++j)
for (int i = 0; i < m; ++i)
A[lda * j + i] = numba();
}
template <typename T, typename U>
void broadcast(T* A, int n, U x) {
for (int i = 0; i < n; ++i)
A[i] = x;
}
template <typename T, typename U>
void broadcast(int m, int n, T* A, int lda, U x) {
for (int j = 0; j < n; ++j)
for (int i = 0; i < m; ++i)
A[lda * j + i] = x;
}
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sgemm.h"
// #include <cosmo.h>
#ifdef __x86_64__
#include <cpuid.h>
#endif
// #include <libc/sysv/consts/hwcap.h>
#include <stdio.h>
#ifdef __aarch64__
#include <sys/auxv.h>
#endif
#include <cassert>
// #include "llamafile.h"
static const struct GemmFuncs {
typeof(llamafile_sgemm)* sgemm;
typeof(llamafile_mixmul)* mixmul;
typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
GemmFuncs() {
#ifdef __x86_64__
// if (X86_HAVE(AVX)) {
// if (X86_HAVE(FMA)) {
// if (X86_HAVE(AVX2)) {
// if (X86_HAVE(AVX512F)) {
// if (X86_HAVE(AVX512VL) && //
// X86_HAVE(AVX512BW) && //
// X86_HAVE(AVX512DQ) && //
// X86_HAVE(AVX512_VNNI) && //
// X86_HAVE(AVX512_BF16)) {
// // AMD Zen4+ (2023-)
// sgemm = llamafile_sgemm_amd_zen4;
// mixmul = llamafile_mixmul_amd_zen4;
// iqk_mixmul = iqk_mul_mat_moe_zen4;
// } else {
// // Intel Xeon Skylake+ (2015-)
// sgemm = llamafile_sgemm_amd_avx512f;
// mixmul = llamafile_mixmul_amd_avx512f;
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else if (X86_HAVE(AVXVNNI)) {
// // Intel Alderlake (2021-)
// sgemm = llamafile_sgemm_amd_avxvnni;
// mixmul = llamafile_mixmul_amd_avxvnni;
// iqk_mixmul = iqk_mul_mat_moe;
// } else {
// // Intel Haswell/Broadwell/Skylake (2013-2020)
// // AMD Excavator (2015-2022)
// sgemm = llamafile_sgemm_amd_avx2;
// mixmul = llamafile_mixmul_amd_avx2;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // AMD Piledriver (2011-2014)
// sgemm = llamafile_sgemm_amd_fma;
// mixmul = llamafile_mixmul_amd_fma;
// if (X86_HAVE(F16C))
// iqk_mixmul = iqk_mul_mat_moe;
// }
// } else {
// // Intel Sandybridge/Ivybridge (2010-2012)
// // AMD Bulldozer (2011)
// sgemm = llamafile_sgemm_amd_avx;
// mixmul = llamafile_mixmul_amd_avx;
// }
// } else {
// // AMD K8/Barcelona (2003-2010)
// // Intel Core/Nehalem (2006-2009)
// sgemm = llamafile_sgemm_unsupported;
// mixmul = llamafile_mixmul_unsupported;
// }
#if defined(__AVX__)
#if defined(__FMA__)
#if defined(__AVX2__)
#if defined(__AVX512F__)
#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
// AMD Zen4+ (2023-)
sgemm = llamafile_sgemm_amd_zen4;
mixmul = llamafile_mixmul_amd_zen4;
iqk_mixmul = iqk_mul_mat_moe_zen4;
#else
// Intel Xeon Skylake+ (2015-)
sgemm = llamafile_sgemm_amd_avx512f;
mixmul = llamafile_mixmul_amd_avx512f;
iqk_mixmul = iqk_mul_mat_moe;
#endif
#elif defined(__AVXVNNI__)
// Intel Alderlake (2021-)
sgemm = llamafile_sgemm_amd_avxvnni;
mixmul = llamafile_mixmul_amd_avxvnni;
iqk_mixmul = iqk_mul_mat_moe;
#else
// Intel Haswell/Broadwell/Skylake (2013-2020)
// AMD Excavator (2015-2022)
sgemm = llamafile_sgemm_amd_avx2;
mixmul = llamafile_mixmul_amd_avx2;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// AMD Piledriver (2011-2014)
sgemm = llamafile_sgemm_amd_fma;
mixmul = llamafile_mixmul_amd_fma;
#if defined(__F16C__)
iqk_mixmul = iqk_mul_mat_moe;
#endif
#endif
#else
// Intel Sandybridge/Ivybridge (2010-2012)
// AMD Bulldozer (2011)
sgemm = llamafile_sgemm_amd_avx;
mixmul = llamafile_mixmul_amd_avx;
#endif
#else
// AMD K8/Barcelona (2003-2010)
// Intel Core/Nehalem (2006-2009)
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
#elif defined(__aarch64__)
long hwcap = getauxval(AT_HWCAP);
if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)
(hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)
(hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)
// e.g. Apple M1, Raspberry Pi 5
sgemm = llamafile_sgemm_arm82;
mixmul = llamafile_mixmul_arm82;
iqk_mixmul = iqk_mul_mat_moe_arm82;
} else {
// ARM64 baseline ISA
sgemm = llamafile_sgemm_arm80;
mixmul = llamafile_mixmul_arm80;
}
#else
sgemm = llamafile_sgemm_unsupported;
mixmul = llamafile_mixmul_unsupported;
#endif
}
} funcs;
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is available.
* Otherwise the caller should fall back to a general matmul routine.
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,
precision);
}
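//
// Minimal usage sketch (illustrative, not from the original source). It
// assumes contiguous f32 matrices with lda = ldb = k and ldc = m, a single
// thread, and the ggml enums GGML_TASK_TYPE_COMPUTE / GGML_TYPE_F32 /
// GGML_PREC_DEFAULT:
//
//   std::vector<float> A(m * k), B(n * k), C(m * n);
//   bool ok = llamafile_sgemm(m, n, k,
//                             A.data(), /*lda=*/k,
//                             B.data(), /*ldb=*/k,
//                             C.data(), /*ldc=*/m,
//                             /*ith=*/0, /*nth=*/1,
//                             GGML_TASK_TYPE_COMPUTE,
//                             GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
//                             GGML_PREC_DEFAULT);
//   if (!ok) {
//       // fall back to a generic matmul; for f32 inputs the call also
//       // returns false when k is not a multiple of the SIMD width
//   }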
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
return funcs.mixmul(params, weights, thought, plan, result);
}
bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {
return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);
}
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.h
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#pragma once
#include <stdbool.h>
#include <cstddef>
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_tensor;
struct ggml_compute_params;
bool iqk_mul_mat(long, long, long, int, const void*, const void*, float*, long, int, int);
bool iqk_mul_mat_zen4(long, long, long, int, const void*, const void*, float*, long, int, int);
bool iqk_mul_mat_arm82(long, long, long, int, const void*, const void*, float*, long, int, int);
bool iqk_mul_mat_moe(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
bool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
bool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
bool llamafile_sgemm(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_mixmul(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
size_t llamafile_mixmul_needs(const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*);
bool llamafile_sgemm_unsupported(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_avx(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_fma(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_avx2(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_avxvnni(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_avx512f(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_amd_zen4(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_arm80(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_sgemm_arm82(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
bool llamafile_mixmul_unsupported(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_avx(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_fma(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_avx2(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_avxvnni(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_avx512f(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_amd_zen4(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_arm80(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_arm82(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
bool llamafile_mixmul_iqk(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
#ifdef __cplusplus
}
#endif
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu.h
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// error, which grows along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
#pragma once
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
// #include "log.h"
#include "flags.h"
#include "sgemm.h"
// #include <cosmo.h>
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#define ROW_ALIGN 64
#define MATRIX_ALIGN 4096
#define MAX_ALIGN 4096
#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif
#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif
#if 0
#define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__)
#else
#define NOT_SUPPORTED false
#endif
#define WANT_QUANTIZATION false
namespace {
bool tinyBLAS_not_supported(const char* file, int line) {
// tinylogf("%s:%d: tinyBLAS not supported\n", file, line);
return false;
}
inline float unhalf(ggml_fp16_t d) {
return GGML_FP16_TO_FP32(d);
}
inline float unhalf(ggml_bf16_t d) {
return GGML_BF16_TO_FP32(d);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MATRIX MEMORY INDEXING
#define NCA 1
#define NCB 2
#define NCC 4
#define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A**)A)[j] + i : A + lda * (j) + i)
////////////////////////////////////////////////////////////////////////////////////////////////////
// GGML TYPE TRAITS
template <typename T>
struct ggml_type_trait;
template <>
struct ggml_type_trait<float> {
static constexpr ggml_type id = GGML_TYPE_F32;
};
template <>
struct ggml_type_trait<ggml_bf16_t> {
static constexpr ggml_type id = GGML_TYPE_BF16;
};
template <>
struct ggml_type_trait<ggml_fp16_t> {
static constexpr ggml_type id = GGML_TYPE_F16;
};
template <>
struct ggml_type_trait<block_q8_0> {
static constexpr ggml_type id = GGML_TYPE_Q8_0;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) {
return _mm_add_ps(x, y);
}
inline __m128 sub(__m128 x, __m128 y) {
return _mm_sub_ps(x, y);
}
inline __m128 mul(__m128 x, __m128 y) {
return _mm_mul_ps(x, y);
}
#endif // __SSE__
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) {
return _mm256_add_ps(x, y);
}
inline __m256 sub(__m256 x, __m256 y) {
return _mm256_sub_ps(x, y);
}
inline __m256 mul(__m256 x, __m256 y) {
return _mm256_mul_ps(x, y);
}
#endif // __AVX__
#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) {
return _mm512_add_ps(x, y);
}
inline __m512 sub(__m512 x, __m512 y) {
return _mm512_sub_ps(x, y);
}
inline __m512 mul(__m512 x, __m512 y) {
return _mm512_mul_ps(x, y);
}
#endif // __AVX512F__
#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) {
return vaddq_f32(x, y);
}
inline float32x4_t sub(float32x4_t x, float32x4_t y) {
return vsubq_f32(x, y);
}
inline float32x4_t mul(float32x4_t x, float32x4_t y) {
return vmulq_f32(x, y);
}
#endif // __ARM_NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) {
return vaddq_f16(x, y);
}
inline float16x8_t sub(float16x8_t x, float16x8_t y) {
return vsubq_f16(x, y);
}
inline float16x8_t mul(float16x8_t x, float16x8_t y) {
return vmulq_f16(x, y);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
/**
* Computes a * b + c.
*/
template <typename T, typename U>
inline U madd(T a, T b, U c) {
return add(mul(a, b), c);
}
/**
* Computes a * b + c with error correction.
*
* @see W. Kahan, "Further remarks on reducing truncation errors,"
* Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,
* doi: 10.1145/363707.363723.
*/
template <typename T, typename U>
inline U madder(T a, T b, U c, U* e) {
U y = sub(mul(a, b), *e);
U t = add(c, y);
*e = sub(sub(t, c), y);
return t;
}
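// In scalar terms madder() is Kahan-style compensated accumulation:
// y = a*b - *e re-injects the low-order bits lost on the previous step,
// t = c + y performs the addition, and *e = (t - c) - y captures the rounding
// error of that addition so it can be folded back in on the next call.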
#ifdef __ARM_NEON
inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e) {
float32x4_t y = sub(vmulq_n_f32(a, b), *e);
float32x4_t t = add(c, y);
*e = sub(sub(t, c), y);
return t;
}
#endif
#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif
#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
return vfmaq_f32(c, a, b);
}
#if 0 // todo: this specialization chops gcc 12.3 performance in half
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) && 0
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
return vfmaq_f16(c, b, a);
}
#endif
#endif
#endif
#if defined(__AVX512BF16__)
template <>
inline __m512 madd(__m512bh x, __m512bh y, __m512 z) {
return _mm512_dpbf16_ps(z, x, y);
}
template <>
inline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512* _) {
return _mm512_dpbf16_ps(z, x, y);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
return vaddvq_f32(x);
}
#endif // __ARM_NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
// todo: this works great on clang but it produces terrible code on gcc 12.3
return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
__m128 t;
t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
x = _mm_add_ps(x, t);
t = _mm_movehl_ps(t, x);
x = _mm_add_ss(x, t);
#endif
return _mm_cvtss_f32(x);
}
#endif
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)));
}
#endif // __AVX__
#if defined(__AVX512F__)
inline float hsum(__m512 x) {
return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING
template <typename T, typename U>
T load(const U*);
template <>
inline float load(const float* p) {
return *p;
}
template <>
inline float load(const ggml_fp16_t* p) {
return unhalf(*p);
}
template <>
inline float load(const ggml_bf16_t* p) {
return unhalf(*p);
}
#if defined(__ARM_NEON)
template <>
inline float32x4_t load(const float* p) {
return vld1q_f32(p);
}
template <>
inline float32x4_t load(const ggml_bf16_t* p) {
return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short*)p), 16));
}
#if !defined(_MSC_VER)
template <>
inline float16x8_t load(const ggml_fp16_t* p) {
return vld1q_f16((const float16_t*)p);
}
template <>
inline float32x4_t load(const ggml_fp16_t* p) {
return vcvt_f32_f16(vld1_f16((const float16_t*)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m128 load(const float* p) {
return _mm_loadu_ps(p);
}
#endif // __SSE__
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 load(const float* p) {
return _mm256_loadu_ps(p);
}
#endif // __AVX__
#if defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 load(const ggml_bf16_t* p) {
return _mm256_castsi256_ps(
_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)p)), 16));
}
#endif // __AVX2__
#if defined(__F16C__)
template <>
inline __m256 load(const ggml_fp16_t* p) {
return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));
}
#endif // __F16C__
#if defined(__AVX512F__)
template <>
inline __m512 load(const float* p) {
return _mm512_loadu_ps(p);
}
template <>
inline __m512 load(const ggml_fp16_t* p) {
return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p));
}
template <>
inline __m512 load(const ggml_bf16_t* p) {
return _mm512_castsi512_ps(
_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)p)), 16));
}
#endif // __AVX512F__
#if defined(__AVX512BF16__)
template <>
inline __m512bh load(const ggml_bf16_t* p) {
return (__m512bh)_mm512_loadu_ps((const float*)p);
}
template <>
inline __m512bh load(const float* p) {
return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
}
#endif // __AVX512BF16__
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT OUTPUT STREAMING
inline void store(float* p, float f) {
*p = f;
}
inline void store(ggml_fp16_t* p, float f) {
*p = GGML_FP32_TO_FP16(f);
}
inline void store(ggml_bf16_t* p, float f) {
*p = GGML_FP32_TO_BF16(f);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
tinyBLAS(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(long m, long n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
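// mnpack() picks the largest register tile (mc x nc) supported for what
// remains of the output, runs gemm<RM, RN>() over the portion that divides
// evenly, then recurses on the right and bottom edge strips until the whole
// m x n output is covered.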
NOINLINE void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;
#if VECTOR_REGISTERS == 32
if (!FLAG_precise) {
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
case 0x55:
mc = 5;
nc = 5;
gemm<5, 5, false>(m0, m, n0, n);
break;
case 0x54:
case 0x53:
case 0x52:
case 0x45:
case 0x44:
case 0x43:
case 0x42:
case 0x35:
case 0x34:
case 0x33:
case 0x32:
case 0x25:
case 0x24:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, false>(m0, m, n0, n);
break;
case 0x51:
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, false>(m0, m, n0, n);
break;
case 0x15:
case 0x14:
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, false>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, false>(m0, m, n0, n);
break;
default:
return;
}
} else {
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3, true>(m0, m, n0, n);
break;
case 0x42:
case 0x33:
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, true>(m0, m, n0, n);
break;
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
}
#endif
#if VECTOR_REGISTERS == 16
if (!FLAG_precise) {
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3, false>(m0, m, n0, n);
break;
case 0x42:
case 0x33:
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, false>(m0, m, n0, n);
break;
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, false>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, false>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, false>(m0, m, n0, n);
break;
default:
return;
}
} else {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2, true>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3, true>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, true>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
}
#endif
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
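// gemm<RM, RN, PRECISE> keeps an RM x RN tile of vector accumulators entirely
// in registers; output tiles are split evenly across the nth threads by index,
// and PRECISE selects the error-compensated accumulator (madder) instead of
// the plain fused multiply-add.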
template <int RM, int RN, int PRECISE>
NOINLINE void gemm(long m0, long m, long n0, long n) {
long ytiles = RM > 1 ? (m - m0) / RM : 1;
long xtiles = RN > 1 ? (n - n0) / RN : 1;
long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (long job = start; job < end; ++job) {
long ii = m0 + job / xtiles * RM;
long jj = n0 + job % xtiles * RN;
D Cv[RN][RM] = {};
D Ce[RN][RM] = {};
for (long l = 0; l < k; l += KN)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
if (PRECISE)
Cv[j][i] = madder(load<V>(INDEX(A, lda, ii + i, l)), //
load<V>(INDEX(B, ldb, jj + j, l)), //
Cv[j][i], &Ce[j][i]);
else
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, l)), //
load<V>(INDEX(B, ldb, jj + j, l)), //
Cv[j][i]);
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
}
const TA* const A;
const TB* const B;
TC* const C;
const long k;
const long lda;
const long ldb;
const long ldc;
const int ith;
const int nth;
};
//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION
#if defined(__ARM_FEATURE_DOTPROD)
template <int CONFIG, typename TA, typename TB, typename TC>
class tinyBLAS_Q0_ARM {
public:
tinyBLAS_Q0_ARM(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(long m, long n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
NOINLINE void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3, false>(m0, m, n0, n);
break;
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, false>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, false>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, false>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, false>(m0, m, n0, n);
break;
default:
return;
}
} else {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3, true>(m0, m, n0, n);
break;
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, true>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
template <int RM, int RN, int PRECISE>
NOINLINE void gemm(long m0, long m, long n0, long n) {
long ytiles = RM > 1 ? (m - m0) / RM : 1;
long xtiles = RN > 1 ? (n - n0) / RN : 1;
long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (long job = start; job < end; ++job) {
long ii = m0 + job / xtiles * RM;
long jj = n0 + job % xtiles * RN;
float32x4_t Cv[RN][RM] = {};
float32x4_t Ce[RN][RM] = {};
for (int l = 0; l < k; ++l)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i) {
float32x4_t a = vcvtq_f32_s32(vdotq_s32(
vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),
load_lo(INDEX(B, ldb, jj + j, l))),
load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l))));
float b = unhalf(INDEX(A, lda, ii + i, l)->d) *
unhalf(INDEX(B, ldb, jj + j, l)->d);
if (PRECISE)
Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]);
else
Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);
}
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
}
inline int8x16_t load_lo(const block_q8_0* b) {
return vld1q_s8(b->qs);
}
inline int8x16_t load_hi(const block_q8_0* b) {
return vld1q_s8(b->qs + 16);
}
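// q4_0 packs two 4-bit quants per byte with an implicit bias of 8; the loaders
// below split the low and high nibbles and subtract 8 so both quant formats
// feed vdotq_s32() as signed int8 lanes.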
inline int8x16_t load_lo(const block_q4_0* b) {
return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))),
vdupq_n_s8(0x8));
}
inline int8x16_t load_hi(const block_q4_0* b) {
return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8));
}
const TA* const A;
const TB* const B;
TC* const C;
const long k;
const long lda;
const long ldb;
const long ldc;
const int ith;
const int nth;
};
#endif // __ARM_FEATURE_DOTPROD
#if defined(__AVX2__) || defined(__AVX512F__)
template <int CONFIG, typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 {
public:
tinyBLAS_Q0_AVX2(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(long m, long n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;
#if VECTOR_REGISTERS == 32
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3, false>(m0, m, n0, n);
break;
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, false>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
} else {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3, true>(m0, m, n0, n);
break;
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, true>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
}
#endif
#if VECTOR_REGISTERS == 16
if (!FLAG_precise) {
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2, false>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3, false>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2, false>(m0, m, n0, n);
break;
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, false>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, false>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, false>(m0, m, n0, n);
break;
default:
return;
}
} else {
switch ((MIN(m - m0, 2) << 4) | MIN(n - n0, 1)) {
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1, true>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2, true>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1, true>(m0, m, n0, n);
break;
default:
return;
}
}
#endif
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
template <int RM, int RN, int PRECISE>
NOINLINE void gemm(long m0, long m, long n0, long n) {
long ytiles = RM > 1 ? (m - m0) / RM : 1;
long xtiles = RN > 1 ? (n - n0) / RN : 1;
long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (long job = start; job < end; ++job) {
long ii = m0 + job / xtiles * RM;
long jj = n0 + job % xtiles * RN;
__m256 Cv[RN][RM] = {};
__m256 Ce[RN][RM] = {};
for (long l = 0; l < k; ++l)
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i) {
__m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) *
unhalf(INDEX(B, ldb, jj + j, l)->d));
__m256 b = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)),
load(INDEX(A, lda, ii + i, l))),
_mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)),
load(INDEX(A, lda, ii + i, l))));
if (PRECISE)
Cv[j][i] = madder(a, b, Cv[j][i], &Ce[j][i]);
else
Cv[j][i] = madd(a, b, Cv[j][i]);
}
#pragma GCC unroll 100
for (int j = 0; j < RN; ++j)
#pragma GCC unroll 100
for (int i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
}
}
inline __m256i load(const block_q8_0* b) {
return _mm256_loadu_si256((const __m256i*)b->qs);
}
inline __m256i load(const block_q4_0* b) {
__m128i x = _mm_loadu_si128((const __m128i*)b->qs);
return _mm256_sub_epi8(_mm256_and_si256(_mm256_set1_epi8(15),
_mm256_insertf128_si256(_mm256_castsi128_si256(x),
_mm_srli_epi16(x, 4), 1)),
_mm256_set1_epi8(8));
}
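// updot() computes the widened dot product of 32 unsigned bytes u with 32
// signed bytes s, accumulating each group of four byte products into a 32-bit
// lane and converting to float: a single dpbusd on AVX-VNNI / AVX512-VNNI,
// emulated with maddubs + madd elsewhere.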
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
return _mm256_cvtepi32_ps(res);
}
const TA* const A;
const TB* const B;
TC* const C;
const long k;
const long lda;
const long ldb;
const long ldc;
const int ith;
const int nth;
};
#endif // __AVX2__
} // namespace
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul.inc
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// MIXTURE OF EXPERTS TENSOR MULTIPLICATION
//
//
// SHAPES
//
// - weights [cols, rows, experts]
// - thought [cols, tasks, tokens] w/ tasks ≤ thinkers
// - result [rows, thinkers, tokens] w/ thinkers ≤ experts
// - plan [thinkers, tokens] w/ i32 < experts
//
// DEFINITION
//
// for thinker in range(thinkers):
// for token in range(tokens):
// for row in range(rows):
// c = 0
// for col in range(cols):
// expert = plan[token][thinker]
// a = weights[expert][row][col]
// b = thought[token][thinker % tasks][col]
// c += a * b
// result[token][thinker][row] = c
//
// REGULARITIES
//
// - tokens can be odd
// - thinkers is usually 2
// - tasks is usually 1 or 2
// - cols should be a multiple of 64
// - rows should be a multiple of 64
// - experts is usually 8 but could be 60
// - tokens is always 1 for token generation
// - tokens can be huge for prompt processing
//
// EXAMPLE
//
// mixtral 8x7b w/ 217 token prompt
//
// | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type
// =========================================================================
// weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0
// thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32
// result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32
// plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32
//
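// STRATEGY
//
// The MixMul class below first quantizes the thought activations when the
// weight type wants a different vec_dot type, then groups every
// (token, thinker) pair by the expert the plan assigns to it
// (build_row_pointers), and finally runs one tinyBLAS matmul per expert over
// that expert's gathered rows, so each expert's weight matrix is traversed
// only once per call.
//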
namespace {
class MixMul {
public:
MixMul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result)
: params(params),
weights(weights),
thought(thought),
plan(plan),
result(result),
rows(weights->ne[1]),
cols(weights->ne[0]),
experts(weights->ne[2]),
thinkers(plan->ne[0]),
tasks(thought->ne[1]),
tokens(thought->ne[2]),
ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN),
wdata_((char*)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)),
allocated_(0) {
}
bool allocate_shared_memory() {
if (!(quantized_thought_ = allocate<char>(MATRIX_ALIGN, tokens * tasks * ldq)))
return false;
if (!(rowptr_result_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))
return false;
if (!(rowptr_thought_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))
return false;
if (!(rowptr_count_ = allocate<long>(sizeof(long), experts)))
return false;
return true;
}
size_t get_allocated_bytes() {
return (wdata_ - (char*)params->wdata) + allocated_;
}
bool mixmul() {
// invariants
assert(tasks <= thinkers);
assert(thinkers <= experts);
assert(tokens == plan->ne[1]);
assert(rows == result->ne[0]);
assert(cols == thought->ne[0]);
assert(tokens == result->ne[2]);
assert(thinkers == result->ne[1]);
// dimensionality
assert(plan->ne[2] == 1);
assert(plan->ne[3] == 1);
assert(result->ne[3] == 1);
assert(weights->ne[3] == 1);
assert(thought->ne[3] == 1);
// miscellaneous
assert(params->nth > 0);
assert(params->ith < params->nth);
assert(plan->type == GGML_TYPE_I32);
// check nb01 is convertible to lda
if (weights->nb[1] % ggml_type_size(weights->type))
return false;
// no support for column strides
if (result->nb[0] != ggml_type_size(result->type))
return false;
if (thought->nb[0] != ggml_type_size(thought->type))
return false;
if (weights->nb[0] != ggml_type_size(weights->type))
return false;
// supported output types
switch (result->type) {
case GGML_TYPE_F32:
return mixmuler<float>();
default:
return false;
}
}
private:
template <typename TC>
bool mixmuler() {
switch (weights->type) {
case GGML_TYPE_F32:
if (thought->type != GGML_TYPE_F32)
return false;
#if defined(__AVX512F__)
return mixmat<16, 1, tinyBLAS<NCB | NCC, 16, __m512, __m512, float, float, TC>, float,
float, TC>();
#elif defined(__AVX__) || defined(__AVX2__)
return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, __m256, __m256, float, float, TC>, float,
float, TC>();
#elif defined(__SSE__)
return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, __m128, __m128, float, float, TC>, float,
float, TC>();
#elif defined(__ARM_NEON)
return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, float, float, TC>,
float, float, TC>();
#else
return false;
#endif
case GGML_TYPE_BF16:
if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16)
return false;
#if defined(__AVX512BF16__)
if (!FLAG_precise) {
return mixmat<
32, 1, tinyBLAS<NCB | NCC, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC>,
ggml_bf16_t, ggml_bf16_t, TC>();
} else {
return mixmat<16, 1,
tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,
ggml_bf16_t, ggml_bf16_t, TC>();
}
#elif defined(__AVX512F__)
return mixmat<16, 1,
tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,
ggml_bf16_t, ggml_bf16_t, TC>();
#elif defined(__AVX2__)
return mixmat<8, 1,
tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, TC>,
ggml_bf16_t, ggml_bf16_t, TC>();
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
return mixmat<
4, 1,
tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_bf16_t, ggml_bf16_t, TC>,
ggml_bf16_t, ggml_bf16_t, TC>();
#else
return false;
#endif
case GGML_TYPE_F16:
if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16)
return false;
#if defined(__AVX512F__)
return mixmat<16, 1,
tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
return mixmat<8, 1,
tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
// } else {
// return false;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (result->op_params[0] == GGML_PREC_F32) {
return mixmat<
4, 1,
tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
} else {
return mixmat<
8, 1,
tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
return mixmat<
4, 1,
tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
#else
return false;
#endif
case GGML_TYPE_Q4_0:
if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__)
return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q4_0, block_q8_0, TC>,
block_q4_0, block_q8_0, TC>();
#elif defined(__ARM_FEATURE_DOTPROD)
return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q4_0, block_q8_0, TC>,
block_q4_0, block_q8_0, TC>();
#else
return false;
#endif
case GGML_TYPE_Q8_0:
if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__)
return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q8_0, block_q8_0, TC>,
block_q8_0, block_q8_0, TC>();
#elif defined(__ARM_FEATURE_DOTPROD)
return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q8_0, block_q8_0, TC>,
block_q8_0, block_q8_0, TC>();
#else
return false;
#endif
default:
return false;
}
}
template <int KN, int BS, typename BLAS, typename TA, typename TB, typename TC>
bool mixmat() {
if (cols % KN)
return false;
switch (params->type) {
case GGML_TASK_TYPE_INIT:
if (thought->type != ggml_type_trait<TB>::id)
quantize_thought(ggml_type_trait<TB>::id);
build_row_pointers(ggml_type_trait<TB>::id);
return true;
case GGML_TASK_TYPE_COMPUTE:
assert(!(cols % BS));
assert(!(weights->nb[1] % sizeof(TA)));
for (int expert = 0; expert < experts; ++expert) {
BLAS tb{cols / BS,
(const TA*)((const char*)weights->data + expert * weights->nb[2]),
(long)(weights->nb[1] / sizeof(TA)),
(const TB*)(rowptr_thought_ + expert * tokens * thinkers),
0,
(TC*)(rowptr_result_ + expert * tokens * thinkers),
0,
params->ith,
params->nth};
tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE);
}
return true;
default:
return true;
}
}
void build_row_pointers(ggml_type vec_dot_type) {
for (int expert = params->ith; expert < experts; expert += params->nth) {
long count = 0;
for (long token = 0; token < tokens; ++token)
for (int thinker = 0; thinker < thinkers; ++thinker)
if (expert == *(const int32_t*)((const char*)plan->data +
token * plan->nb[1] + thinker * plan->nb[0])) {
long row = count++;
long idx = expert * thinkers * tokens + row;
rowptr_result_[idx] =
(uintptr_t)((char*)result->data + token * result->nb[2] +
thinker * result->nb[1]);
if (thought->type == vec_dot_type)
rowptr_thought_[idx] =
(uintptr_t)((char*)thought->data + token * thought->nb[2] +
thinker % tasks * thought->nb[1]);
else
rowptr_thought_[idx] =
(uintptr_t)((char*)quantized_thought_ + token * tasks * ldq +
thinker % tasks * ldq);
}
rowptr_count_[expert] = count;
}
}
void quantize_thought(ggml_type vec_dot_type) {
long chore = 0;
for (long token = 0; token < tokens; ++token)
for (int task = 0; task < tasks; ++task)
if (chore++ % params->nth == params->ith)
quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq,
(const float*)((const char*)thought->data +
token * thought->nb[2] + task * thought->nb[1]),
vec_dot_type);
}
void quantize_row(void* dst, const float* src, ggml_type type) {
assert((long)ggml_row_size(type, cols) <= ldq);
switch (type) {
case GGML_TYPE_F16:
ggml_fp32_to_fp16_row(src, (ggml_fp16_t*)dst, cols);
break;
case GGML_TYPE_BF16:
ggml_fp32_to_bf16_row(src, (ggml_bf16_t*)dst, cols);
break;
case GGML_TYPE_Q8_0:
quantize_row_q8_0((const float*)src, (block_q8_0*)dst, cols);
break;
default:
GGML_UNREACHABLE();
}
}
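// allocate<T>() below is a small bump allocator over the shared wdata buffer:
// it rounds the running offset up to the requested alignment, reserves
// sizeof(T) * elems bytes, and returns nullptr if that would exceed
// params->wsize.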
template <typename T>
T* allocate(size_t align, size_t elems) {
T* res = nullptr;
size_t need = sizeof(T) * elems;
size_t base = allocated_;
base += align - 1;
base &= -align;
size_t toto = base + need;
if (toto >= allocated_ && toto <= params->wsize) {
res = (T*)(wdata_ + base);
allocated_ = toto;
}
return res;
}
const ggml_compute_params* const params;
const ggml_tensor* const weights;
const ggml_tensor* const thought;
const ggml_tensor* const plan;
ggml_tensor* const result;
const long rows;
const long cols;
const int experts;
const int thinkers;
const int tasks;
const long tokens;
const long ldq;
// variables
char* const wdata_;
size_t allocated_;
// shared memory
long* rowptr_count_ /*[experts]*/;
char* quantized_thought_ /*[tokens][tasks][cols][2]*/;
uintptr_t* rowptr_result_ /*[experts][tokens*thinkers]*/;
uintptr_t* rowptr_thought_ /*[experts][tokens*thinkers]*/;
};
} // namespace
/**
* Performs "mixture of experts" tensor multiplication on CPU.
*/
bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {
MixMul mm{params, weights, thought, plan, result};
return mm.allocate_shared_memory() && mm.mixmul();
}
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_avx
#include "tinyblas_cpu_mixmul.inc"
/**
* Returns number of shared memory bytes llamafile_mixmul() needs.
*/
size_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {
ggml_compute_params params{};
params.wsize = 0x7ffff000;
params.wdata = (void*)0x1000;
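// wsize/wdata are dummy values: MixMul only uses them to replay its aligned
// allocations, so get_allocated_bytes() reports the real requirement without
// touching memory.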
MixMul mm{&params, weights, thought, plan, 0};
if (mm.allocate_shared_memory())
return mm.get_allocated_bytes();
else
return 0;
}
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_avx2
#include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_avx512f
#include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_avxvnni
#include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_fma
#include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_mixmul llamafile_mixmul_amd_zen4
#include "tinyblas_cpu_mixmul.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm80.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __aarch64__
#define llamafile_mixmul llamafile_mixmul_arm80
#include "tinyblas_cpu_mixmul.inc"
/**
* Returns number of shared memory bytes llamafile_mixmul() needs.
*/
size_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {
ggml_compute_params params{};
params.wsize = 0x7ffff000;
params.wdata = (void*)0x1000;
MixMul mm{&params, weights, thought, plan, 0};
if (mm.allocate_shared_memory())
return mm.get_allocated_bytes();
else
return 0;
}
#endif // __aarch64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm82.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __aarch64__
#define llamafile_mixmul llamafile_mixmul_arm82
#include "tinyblas_cpu_mixmul.inc"
#endif // __aarch64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas_cpu.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// error, which grows along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
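// llamafile_sgemm_impl() dispatches on Atype/Btype and the compile-time ISA:
// it instantiates a matching tinyBLAS (or tinyBLAS_Q0_*) kernel when one
// exists, returns WANT_QUANTIZATION when the caller should first quantize B to
// the vec_dot type, and NOT_SUPPORTED when it must fall back to ggml's
// reference path.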
namespace {
template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (k % 4)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{
k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
if (k % 32)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_BF16)
return NOT_SUPPORTED;
if (!FLAG_precise) {
tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX2__)
if (k % 8)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{
k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
// if (X86_CHECK(F16C)) {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32 && n < 2) {
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
// } else {
// return NOT_SUPPORTED;
// }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (precision == GGML_PREC_F32) {
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
} else {
if (k % 8)
return NOT_SUPPORTED;
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
}
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_SUPPORTED;
if (k % 4)
return NOT_SUPPORTED;
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{
k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_Q8_0)
return NOT_SUPPORTED;
#if defined(__AVX2__) || defined(__AVX512F__)
tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{
k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n, task);
return true;
#else
return NOT_SUPPORTED;
#endif
}
default:
return NOT_SUPPORTED;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
 *                     GGML_TASK_TYPE_COMPUTE,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
 *                     GGML_PREC_DEFAULT);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
 * @param task is GGML task type
 * @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
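    // Fast path: hand the request to the iqk quantized kernels first; they
    // return true when they can service this shape/type combination, and we
    // fall through to the tinyBLAS dispatch below otherwise.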
#if QK_K == 256
#if defined(__x86_64__)
#if defined(__AVX2__) && defined(__FMA__)
// if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
// }
#endif
#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
// assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);
assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));
if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {
return true;
}
}
#endif
#endif
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,
Btype, Ctype, precision);
default:
return NOT_SUPPORTED;
}
}
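// Usage sketch (illustrative only, fenced off with #if 0): each of `nth`
// worker threads calls llamafile_sgemm with its own `ith` and otherwise
// identical arguments, and the kernels partition the output among threads.
// The helper name sgemm_threaded and the float32 / GGML_TASK_TYPE_COMPUTE
// choices are assumptions made for this example, not part of this file.
#if 0
#include <thread>
#include <vector>

static void sgemm_threaded(long m, long n, long k,
                           const float* A, long lda,
                           const float* B, long ldb,
                           float* C, long ldc, int nth) {
    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith)
        workers.emplace_back([=] {
            llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth,
                            GGML_TASK_TYPE_COMPUTE,
                            GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
                            GGML_PREC_DEFAULT);
        });
    for (auto& t : workers)
        t.join();
}
#endif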
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_sgemm llamafile_sgemm_amd_avx
#include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_sgemm llamafile_sgemm_amd_avx2
#include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__
// Adapted from
// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp
// Copyright 2024 Mozilla Foundation.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__
#define llamafile_sgemm llamafile_sgemm_amd_avx512f
#include "tinyblas_cpu_sgemm.inc"
#endif // __x86_64__
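// Note: each of the translation units above renames llamafile_sgemm to an
// ISA-specific symbol (avx, avx2, avx512f) before including the shared
// tinyblas_cpu_sgemm.inc, presumably so a runtime CPU-feature dispatcher can
// select the best variant compiled for the host.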