Unverified Commit 4987f13d authored by Gabe Goodhart, committed by GitHub

Llama cpp bump (df1b612): granite docling / mamba2 optimizations / multimodal encoding fixes (#12552)

* feat: Bump llama.cpp to df1b612

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(mtmd): Correctly encode text chunks during mtmd tokenization

For some models, text chunks containing template delimiter tokens can appear
interspersed with the image embeddings. These chunks need to be correctly
translated to text tokens.
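
A minimal sketch of the idea against llama.cpp's mtmd C API (mtmd.h), for
illustration only; collect_text_tokens is a hypothetical helper, not the code
changed here:

    #include <cstddef>
    #include <vector>
    #include "mtmd.h"   // llama.cpp multimodal chunk API

    // walk the chunks produced by mtmd tokenization; text chunks (e.g. template
    // delimiters between image slots) become regular text tokens, image chunks
    // are left for the vision-encoder path
    static std::vector<llama_token> collect_text_tokens(const mtmd_input_chunks * chunks) {
        std::vector<llama_token> out;
        for (size_t i = 0; i < mtmd_input_chunks_size(chunks); ++i) {
            const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
            if (mtmd_input_chunk_get_type(chunk) == MTMD_INPUT_CHUNK_TYPE_TEXT) {
                size_t n_tokens = 0;
                const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
                out.insert(out.end(), tokens, tokens + n_tokens);
            }
            // non-text chunks carry image/audio data and are encoded separately
        }
        return out;
    }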

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* tests: Use MtmdChunk in image_test

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Fix unnecessary conversion linting

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(ggml): Revert changes to ggml_hip.cpp

These changes were done largely by our code assistant and are likely wrong

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Revert changes in mem_nvml.cpp

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update sync point to 1deee0

This brings in several more optimization commits and model support for
EmbeddingGemma

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update patches for 1deee0

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: sync for bump to 1deee0

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Bad patch updates with errant `+`

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Bump llama.cpp/ggml to 7049736

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: format-patches after latest bump

Branch: LlamaCPPBump-GraniteDocling
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent e638f2ac
#include "unary.cuh"
#include "convert.cuh"
static __device__ __forceinline__ float op_abs(float x) {
return fabsf(x);
@@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
/* CUDA kernel + launcher for xIELU */
template <typename T>
static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const float xi = ggml_cuda_cast<float>(x[i]);
const float gate_pos = (xi > 0.0f);
const float y_pos = alpha_p * xi * xi + beta * xi;
const float min_v_eps = fminf(xi, eps);
const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
dst[i] = ggml_cuda_cast<T>(out);
}
template <typename T>
static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE;
xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
}
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
const float alpha_n = ggml_get_op_params_f32(dst, 1);
const float alpha_p = ggml_get_op_params_f32(dst, 2);
const float beta = ggml_get_op_params_f32(dst, 3);
const float eps = ggml_get_op_params_f32(dst, 4);
if (src0->type == GGML_TYPE_F16) {
xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
} else {
xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
}
}
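
// For reference, an editorial C++ sketch of the per-element math in xielu_kernel above
// (same parameter names; not part of the diff itself):
#include <cmath>

static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    const float gate_pos = (x > 0.0f) ? 1.0f : 0.0f;                                 // selects the positive branch
    const float y_pos    = alpha_p * x * x + beta * x;                               // x > 0
    const float y_neg    = (std::expm1(std::fmin(x, eps)) - x) * alpha_n + beta * x; // x <= 0
    return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
}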
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
@@ -16,6 +16,7 @@
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
#define CUDA_GLU_BLOCK_SIZE 256
#define CUDA_XIELU_BLOCK_SIZE 256
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -8,6 +8,9 @@
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#if defined(GGML_HIP_ROCWMMA_FATTN)
#include <rocwmma/rocwmma-version.hpp>
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -39,12 +39,6 @@ endif()
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
if (GGML_HIP_ROCWMMA_FATTN)
CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
if (NOT ${FOUND_ROCWMMA})
message(FATAL_ERROR "rocwmma has not been found")
endif()
endif()
if (${hip_VERSION} VERSION_LESS 6.1)
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
@@ -59,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
@@ -117,10 +113,6 @@ if (NOT GGML_HIP_MMQ_MFMA)
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
endif()
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
endif()
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
@@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
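// note: for input > 20, expf(input) exceeds 4.8e8 and logf(1 + expf(input)) equals input to
// within float precision (log(1 + exp(20)) ≈ 20.000000002), so the early return both saves
// work and avoids overflowing expf for very large inputs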
static inline float ggml_softplus(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
//
// logging
//
@@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t
}
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
}
@@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor *
}
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
return false;
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -338,7 +357,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
char base[256];
char name[256];
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
const char * suffix = "";
if (op->src[1]->ne[0] % 4 == 0) {
suffix = "_4";
}
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +377,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
char base[256];
char name[256];
if (op->src[3]->ne[0] == 1) {
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
} else {
const int nsg = (ne00 + 31)/32;
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@@ -369,7 +394,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
return res;
}
@@ -918,6 +943,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_pad");
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
base,
has_mask,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_blk");
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
base,
nqptg,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const ggml_tensor * op,
@@ -925,6 +1040,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -937,18 +1053,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
// do bounds checks for the mask?
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
bc_mask,
ns10,
ns20,
nsg);
@@ -964,6 +1085,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1107,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1127,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg, nwg);
@@ -1023,6 +1149,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -1374,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_SGD);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -134,6 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
@@ -142,6 +157,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +167,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg);
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
@@ -776,9 +777,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
{
return op->ne[3] == 1;
}
return true;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {
@@ -800,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
default:
return false;
}
@@ -1954,11 +1954,20 @@ GGML_TABLE_END()
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT_BLK 200
#define FC_FLASH_ATTN_EXT 300
#define FC_FLASH_ATTN_EXT_VEC 400
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
// kernel argument structs
//
@@ -2063,6 +2072,7 @@ typedef struct {
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
int64_t ne01;
int64_t ne02;
@@ -2128,6 +2138,35 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
typedef struct {
int32_t ne01;
int32_t ne30;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_blk;
typedef struct {
int32_t ne01;
int32_t ne02;
@@ -2146,6 +2185,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@@ -2180,6 +2220,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@@ -2388,6 +2429,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -2457,32 +2502,45 @@ typedef struct {
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t s_off;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t ns12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t ns21;
uint64_t nb22;
int64_t ne30;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t ns42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t ns52;
uint64_t nb53;
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
int32_t ne00t;
int32_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb03;
int32_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_get_rows;
typedef struct {
@@ -2604,6 +2662,14 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL
#include <metal_stdlib>
@@ -4322,6 +4388,24 @@ kernel void kernel_geglu_quick_f32(
}
}
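// naive whole-tensor sum: only thread 0 of the threadgroup accumulates all args.np
// elements serially and writes the single scalar result to dst[0]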
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
if (tiitg != 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}
dst[0] = acc;
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
@@ -4631,124 +4715,39 @@ kernel void kernel_ssm_conv_f32_f32(
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
kernel void kernel_ssm_conv_f32_f32_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
const int64_t i0 = tpitg.x;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
s = state;
// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into SIMD groups of size `sgptg`. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.
// Computed for each thread
float sumf = state * C[i0];
// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);
if (sgptg > 1) {
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
threadgroup_barrier(mem_flags::mem_threadgroup);
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
// For simd group 0 at indices < num simd groups, extract the shared
// simd sum
sumf = 0.0f;
if (sgitg == 0) {
if (tiisg < sgptg) {
sumf = shared[tiisg];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
y[0] = sumf;
}
}
} else if (tiisg == 0) {
y[0] = sumf;
}
float sumf = 0.0f;
// recurse
s0 = s;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
// Assign the final state to the output buffer
s_buff[i] = s;
x[0] = sumf;
}
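// note: the host selects the _4 pipeline above only when the conv width (src1 ne[0]) is a
// multiple of 4 (see ggml_metal_library_get_pipeline_ssm_conv), so the taps can be consumed
// as float4 dot products: nc/4 iterations instead of nc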
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
kernel void kernel_ssm_scan_group_f32(
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
@@ -4760,102 +4759,87 @@ kernel void kernel_ssm_scan_group_f32(
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
const int64_t i0 = tpitg.x;
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
shared[tpitg.x] = 0.0f;
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int64_t s_off = args.s_off;
const int32_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
const float state = (s0 * dA) + (B[i0] * x_dt);
s = state;
// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into SIMD groups of size `sgptg`. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.
// Computed for each thread
float sumf = state * C[i0];
// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);
float s = 0.0f;
// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// For simd group 0 at indices < num simd groups, extract the shared
// simd sum
sumf = 0.0f;
if (sgitg == 0) {
if (tiisg < sgptg) {
sumf = shared[tiisg];
}
sumf = simd_sum(sumf);
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float dt0 = dt[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
const float x_dt = x[0] * dtsp;
const float dA = exp(dtsp * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
y[0] = sumf;
}
shared[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
x += args.ns12;
dt += args.ns21;
B += args.ns42;
C += args.ns52;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
// Assign the final state to the output buffer
s_buff[i] = s;
}
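// note on the reduction above: each inner iteration stores one timestep's per-simdgroup
// partial sum into shared[t*NW + sgitg]; after a single threadgroup barrier, simdgroup sgitg
// reduces row sgitg of that buffer with simd_sum and writes the y value for timestep
// i2 + sgitg, amortizing one barrier over up to sgptg timesteps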
@@ -7112,10 +7096,142 @@ kernel void kernel_leaky_relu_f32_4(
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
// pad the last chunk of C elements of k and v into an extra pad buffer
kernel void kernel_flash_attn_ext_pad(
constant ggml_metal_kargs_flash_attn_ext_pad & args,
device const char * k,
device const char * v,
device const char * mask,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t C = FC_flash_attn_ext_pad_ncpsg;
device char * k_pad = dst;
device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
const int32_t icp = args.ne11 % C;
const int32_t ic0 = args.ne11 - icp;
const int32_t i1 = tgpig[0];
const int32_t i2 = tgpig[1];
const int32_t i3 = tgpig[2];
if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
if (i1 >= icp) {
// here it is not important the exact value that will be used as we rely on masking out the scores in the attention
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = 0;
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = 0;
}
} else {
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = k_src[i];
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = v_src[i];
}
}
}
if (FC_flash_attn_ext_pad_has_mask) {
if (i2 < args.ne32 && i3 < args.ne33) {
for (int ib = i1; ib < args.ne31; ib += C) {
device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
for (int i = tiitg; i < C; i += ntg.x) {
if (i >= icp) {
mask_dst[i] = -MAXHALF;
} else {
mask_dst[i] = mask_src[i];
}
}
}
}
}
}
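// For reference, an editorial host-side C++ sketch of the pad buffer layout written by
// kernel_flash_attn_ext_pad above (field names follow ggml_metal_kargs_flash_attn_ext_pad;
// C is the ncpsg chunk size):
#include <cstddef>
#include <cstdint>

struct fa_pad_offsets {
    size_t k;     // padded K rows
    size_t v;     // padded V rows
    size_t mask;  // padded mask (half), set to -INF past ne11
    size_t total; // minimum buffer size implied by the indexing above
};

static fa_pad_offsets fa_pad_layout(size_t nb11, size_t nb21,
                                    int64_t ne_12_2, int64_t ne_12_3,
                                    int64_t ne31, int64_t ne32, int64_t ne33,
                                    int32_t C) {
    fa_pad_offsets o;
    o.k     = 0;
    o.v     = o.k + nb11*C*ne_12_2*ne_12_3;
    o.mask  = o.v + nb21*C*ne_12_2*ne_12_3;
    o.total = o.mask + sizeof(uint16_t)*C*ne31*ne32*ne33; // mask entries are half precision
    return o;
}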
constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
kernel void kernel_flash_attn_ext_blk(
constant ggml_metal_kargs_flash_attn_ext_blk & args,
device const char * mask,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
// block size C x Q
const int32_t Q = FC_flash_attn_ext_blk_nqptg;
const int32_t C = FC_flash_attn_ext_blk_ncpsg;
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig[2]/args.ne32;
const int32_t i2 = tgpig[2]%args.ne32;
const int32_t i1 = tgpig[1];
const int32_t i0 = tgpig[0];
char res = i0*C + C > args.ne30 ? 1 : 0;
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
// fast route
if (res == 0) {
if (simd_max(*mask_src) > -MAXHALF/2) {
res = 1;
}
}
// detailed check of the elements of the block
if ((C > NW || Q > 1) && res == 0) {
half m = -MAXHALF;
FOR_UNROLL (short j = 0; j < Q; ++j) {
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
m = max(m, mask_src[ii*NW]);
}
mask_src += args.nb31/2;
}
if (simd_max(m) > -MAXHALF/2) {
res = 1;
}
}
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
const int32_t nblk0 = ((args.ne30 + C - 1)/C);
if (tiisg == 0) {
dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
}
}
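// For reference, an editorial C++ sketch of how the block map produced above is indexed;
// the main flash-attention kernel reads the same map to decide whether a whole KV block can
// be skipped (0 = fully masked, 1 = has at least one usable element):
#include <cstdint>

static int64_t fa_blk_index(int32_t i0, int32_t i1, int32_t i2, int32_t i3,
                            int32_t ne01, int32_t ne30, int32_t ne32,
                            int32_t Q, int32_t C) {
    const int32_t nblk1 = (ne01 + Q - 1)/Q; // number of Q-sized query blocks
    const int32_t nblk0 = (ne30 + C - 1)/C; // number of C-sized KV blocks
    return ((int64_t(i3)*ne32 + i2)*nblk1 + i1)*int64_t(nblk0) + i0;
}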
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -7162,6 +7278,8 @@ void kernel_flash_attn_ext_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device const char * blk,
device char * dst,
threadgroup half * shmem_f16,
uint3 tgpig,
@@ -7227,6 +7345,13 @@ void kernel_flash_attn_ext_impl(
pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
}
{
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
const int32_t nblk0 = ((args.ne11 + C - 1)/C);
blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
}
{
q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
@@ -7286,16 +7411,75 @@ void kernel_flash_attn_ext_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic = 0; ic < args.ne11; ic += C) {
for (int ic0 = 0; ; ++ic0) {
int ic = ic0*C;
if (ic >= args.ne11) {
break;
}
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_has_mask) {
threadgroup half * sm = (threadgroup half *) (sm2);
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
for (short i = tiisg; i < C; i += NW) {
if (ic + i >= args.ne11) {
sm[2*j*SH + i] = -MAXHALF;
}
}
}
} else {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
pm2[jj] = (device const half2 *) ((device const half *) mask +
(iq1 + j)*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32));
}
}
ic = 0;
}
// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
if (blk[ic0] == 0) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
continue;
}
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
}
pm2[jj] += NW;
}
#if 0
// note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
threadgroup_barrier(mem_flags::mem_threadgroup);
// used to detect blocks full of -INF
@@ -7314,13 +7498,14 @@ void kernel_flash_attn_ext_impl(
continue;
}
#endif
}
// Q*K^T
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
threadgroup const q_t * pq = sq;
threadgroup s_t * ps = ss;
@@ -7331,26 +7516,24 @@ void kernel_flash_attn_ext_impl(
constexpr short NC = (C/8)/NSG;
// TODO: not good to unroll for large contexts - not sure why?
// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
if (DK8 % 16 != 0) {
if (DK % 16 != 0) {
k8x8_t mk;
q8x8_t mq;
FOR_UNROLL (short i = 0; i < DK8; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mk, pk, NS10, 0, true);
simdgroup_load(mq, pq, DK);
simdgroup_load(mk, pk + 8*i, NS10, 0, true);
simdgroup_load(mq, pq + 8*i, DK);
simdgroup_barrier(mem_flags::mem_none);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
pk += 8;
pq += 8;
}
} else {
k8x8_t mk[2];
@@ -7359,26 +7542,22 @@ void kernel_flash_attn_ext_impl(
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
simdgroup_load(mq[0], pq + 0*8, DK);
simdgroup_load(mq[1], pq + 1*8, DK);
simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
simdgroup_barrier(mem_flags::mem_none);
simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
pk += 16;
pq += 16;
}
}
simdgroup_store(mqk, ps, SH, 0, false);
pk += 8*(NSG*NS10 - DK8);
pq += 8*(NSG*0 - DK8);
pk += 8*(NSG*NS10);
ps += 8*(NSG);
}
} else {
@@ -7392,7 +7571,7 @@ void kernel_flash_attn_ext_impl(
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
for (short ii = 0; ii < DK16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
if (DK16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -7512,27 +7691,50 @@ void kernel_flash_attn_ext_impl(
}
{
auto sst = ss;
device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
pv += 8*sgitg;
if (DV <= 64) {
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
s8x8_t vs;
simdgroup_load(vs, sst, SH, 0, false);
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
v8x8_t mv;
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
v8x8_t mv[2];
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
}
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
s8x8_t vs[2];
simdgroup_load(mv, pv, NS20, 0, false);
simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
pv += 8*NSG;
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
v8x8_t mv[4];
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
}
pv += 8*(NS20 - NO*NSG);
sst += 8;
pv += 2*8*NS20;
}
}
}
@@ -7556,7 +7758,7 @@ void kernel_flash_attn_ext_impl(
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
if (DV16%4 == 0) {
// no need for bound checks
@@ -7646,7 +7848,7 @@ void kernel_flash_attn_ext_impl(
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
const float scale = 1.0f/S[jj];
const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
if (DV4 % NW == 0) {
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
@@ -7691,8 +7893,8 @@ template<
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short DK, // K head size
short DV, // V head size
short Q = 8, // queries per threadgroup
short C = 64> // cache items per threadgroup
short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q,
@@ -7700,13 +7902,15 @@ kernel void kernel_flash_attn_ext(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device const char * blk,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_nsg) {
// note: disabled cases to reduce library load time
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -7826,6 +8030,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -7852,9 +8057,9 @@ template<
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32, // cache items per threadgroup
short NE, // head elements per thread
short Q, // queries per threadgroup
short C, // cache items per threadgroup
short NSG> // number of simd groups
void kernel_flash_attn_ext_vec_impl(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
@@ -7863,6 +8068,7 @@ void kernel_flash_attn_ext_vec_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -7968,12 +8174,38 @@ void kernel_flash_attn_ext_vec_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
const int ic = ic0 + C*sgitg;
for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
int ic = ic0*C;
if (ic >= args.ne11) {
break;
}
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_vec_has_mask) {
if (ic + tiisg >= args.ne11) {
sm[tiisg] = -MAXHALF;
}
} else {
pm = (device const half *) (mask) +
iq1*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32);
}
ic = 0;
}
if (FC_flash_attn_ext_vec_has_mask) {
sm[tiisg] = pm[ic + tiisg];
}
@@ -7985,7 +8217,7 @@ void kernel_flash_attn_ext_vec_impl(
// Q*K^T
{
device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
threadgroup const q4_t * pq4 = sq4;
pk4 += ty*NS10/4 + tx;
@@ -8000,7 +8232,7 @@ void kernel_flash_attn_ext_vec_impl(
mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
}
} else {
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
k4_t mk;
@@ -8098,7 +8330,7 @@ void kernel_flash_attn_ext_vec_impl(
}
if (is_same<vd4_t, v4_t>::value) {
device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
pv4 += ty*NS20/4 + tx;
@@ -8111,7 +8343,7 @@ void kernel_flash_attn_ext_vec_impl(
}
} else {
FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
const short i = ii*NL + tx;
@@ -8236,7 +8468,7 @@ void kernel_flash_attn_ext_vec_impl(
device float4 * dst4 = (device float4 *) dst;
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
// interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
@@ -8274,8 +8506,8 @@ template<
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
......@@ -8283,13 +8515,14 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_vec_nsg) {
// note: disabled cases to reduce library load time
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -8413,7 +8646,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
const float m = simd_max(M);
const float ms = exp(M - m);
S = 1.0f/simd_sum(S*ms);
S = simd_sum(S*ms);
S = S == 0.0f ? 0.0f : 1.0f/S;
const short DV4 = DV/4;
@@ -8433,21 +8667,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
}
template<typename T0, typename T1>
kernel void kernel_cpy(
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 tptg[[threads_per_threadgroup]]) {
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
@@ -8458,190 +8688,70 @@ kernel void kernel_cpy(
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
// TODO: templatize these kernels
kernel void kernel_cpy_f32_q8_0(
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q8_0(src, dst_data[i00/QK8_0]);
}
}
kernel void kernel_cpy_f32_q4_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q4_0(src, dst_data[i00/QK4_0]);
}
}
kernel void kernel_cpy_f32_q4_1(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q4_1(src, dst_data[i00/QK4_1]);
}
}
kernel void kernel_cpy_f32_q5_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_q5_0(src, dst_data[i00/QK5_0]);
}
}
quantize_func(src, dst_data[i00]);
kernel void kernel_cpy_f32_q5_1(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q5_1(src, dst_data[i00/QK5_1]);
break;
}
}
kernel void kernel_cpy_f32_iq4_nl(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
}
}
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
......@@ -8649,11 +8759,12 @@ kernel void kernel_cpy_q_f32(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
......@@ -8665,10 +8776,12 @@ kernel void kernel_cpy_q_f32(
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
......@@ -10121,7 +10234,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
......@@ -10134,13 +10247,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -10151,6 +10263,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK4_NL;
const int ns01 = args.nb01/args.nb00;
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
......@@ -10158,24 +10273,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK4_NL + it * 8;
device const float * yb = y + ix*QK4_NL + it*8;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
float4 qf1, qf2;
for (int ib = ix; ib < nb; ib += 16) {
// [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (short row = 0; row < nr0; row++) {
device const block_iq4_nl & xb = x[row*nb + ib];
for (short row = 0; row < NR0; row++) {
device const block_iq4_nl & xb = x[row*ns01 + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
float4 acc1 = {0.f}, acc2 = {0.f};
......@@ -10206,7 +10322,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -10228,7 +10344,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
......@@ -10241,12 +10357,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -10257,6 +10372,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK_K;
const int ns01 = args.nb01/args.nb00;
const short ix = tiisg/16; // 0 or 1
const short it = tiisg%16; // 0...15
const short ib = it/2;
......@@ -10266,7 +10384,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
......@@ -10275,15 +10393,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
float4 qf1, qf2;
for (int ibl = ix; ibl < nb; ibl += 2) {
// [TAG_MUL_MV_WEIRD]
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (short row = 0; row < nr0; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
for (short row = 0; row < NR0; ++row) {
device const block_iq4_xs & xb = x[row*ns01 + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
float4 acc1 = {0.f}, acc2 = {0.f};
......@@ -10313,7 +10432,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -10335,7 +10454,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
......@@ -10348,13 +10467,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -10365,6 +10483,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK_MXFP4;
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
......@@ -10372,20 +10493,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK_MXFP4 + it * 8;
device const float * yb = y + ix*QK_MXFP4 + it*8;
// note: the check `ib < nb` alone is sufficient, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
// it is not clear why - needs deeper investigation [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *) yb;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
#pragma unroll(nr0)
for (short row = 0; row < nr0; row++) {
device const block_mxfp4 & xb = x[row*nb + ib];
FOR_UNROLL (short row = 0; row < NR0; row++) {
device const block_mxfp4 & xb = x[row*ns01 + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
......@@ -10403,7 +10526,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -10430,64 +10553,58 @@ kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
const int32_t i02 = i11;
const int32_t i03 = i12;
for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
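// indexing sketch for the rewritten get_rows kernels (reading of the host dispatch in
// ggml_metal_op_get_rows, illustrative numbers only): the grid is nw0*ne10 threadgroups wide with
// nw0 = ceil(ne00t/nth), so tgpig.x%args.ne10 recovers the row index i10 and tgpig.x/args.ne10 the
// chunk index iw0; each thread handles at most one output element ind = iw0*ntg.x + tiitg
// (one float4x4 chunk in the quantized variant). e.g. ne00t = 100 with nth = 32 -> nw0 = 4 chunks,
// the last one only partially populated.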
template<typename T>
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
}
}
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
kernel void kernel_get_rows_i32(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int32_t i02 = i11;
const int32_t i03 = i12;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
const int64_t i02 = i11;
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
break;
}
}
......@@ -10973,12 +11090,13 @@ kernel void kernel_mul_mm_id(
// get rows
//
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
......@@ -11383,3 +11501,51 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
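// AdamW update performed above, written out for reference (pars[] layout read off the loads at the
// top of the kernel; beta1h and beta2h are assumed to be bias-correction factors precomputed on
// the host):
//
//   m = beta1*m + (1 - beta1)*g
//   v = beta2*v + (1 - beta2)*g*g
//   x = x*(1 - alpha*wd) - alpha*(beta1h*m) / (sqrt(beta2h*v) + eps)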
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
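// plain SGD step with decoupled weight decay, written out (assuming pars[0] = learning rate and
// pars[1] = weight decay, based on how the values are used above):
//
//   x = x*(1 - lr*wd) - lr*g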
......@@ -69,11 +69,20 @@
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT_BLK 200
#define FC_FLASH_ATTN_EXT 300
#define FC_FLASH_ATTN_EXT_VEC 400
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
// kernel argument structs
//
......@@ -178,6 +187,7 @@ typedef struct {
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
int64_t ne01;
int64_t ne02;
......@@ -243,6 +253,35 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
typedef struct {
int32_t ne01;
int32_t ne30;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_blk;
typedef struct {
int32_t ne01;
int32_t ne02;
......@@ -261,6 +300,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
......@@ -295,6 +335,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
......@@ -503,6 +544,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct {
int64_t ne00;
int64_t ne01;
......@@ -572,32 +617,45 @@ typedef struct {
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t s_off;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t ns12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t ns21;
uint64_t nb22;
int64_t ne30;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t ns42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t ns52;
uint64_t nb53;
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
int32_t ne00t;
int32_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb03;
int32_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_get_rows;
typedef struct {
......@@ -719,4 +777,12 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL
......@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
......@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ggml_is_contiguous(node->src[1]), node->src[1]->name);
}
if (node->src[2]) {
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
ggml_is_contiguous(node->src[2]), node->src[2]->name);
}
if (node->src[3]) {
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
ggml_is_contiguous(node->src[3]), node->src[3]->name);
}
if (node) {
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
node->name);
......@@ -289,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_glu(ctx, idx);
} break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
......@@ -398,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
case GGML_OP_OPT_STEP_SGD:
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
......@@ -577,6 +601,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
......@@ -827,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
ggml_metal_kargs_sum args = {
/*.np =*/ n,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
......@@ -906,23 +955,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
ggml_metal_kargs_get_rows args = {
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nw0 = (args.ne00t + nth - 1)/nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
return 1;
}
......@@ -1172,25 +1229,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ns12 =*/ nb12/nb10,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.ns21 =*/ nb21/nb20,
/*.nb22 =*/ nb22,
/*.ne30 =*/ ne30,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.ns42 =*/ nb42/nb40,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.ns52 =*/ nb52/nb50,
/*.nb53 =*/ nb53,
/*.nb0 =*/ nb0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
......@@ -1206,13 +1274,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
if (ne30 == 1) {
// Mamba-2
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
} else {
GGML_ASSERT(d_inner == 1);
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
}
return 1;
}
......@@ -1273,26 +1335,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
// TODO: support
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
const int32_t nk00 = ne00;
int nth = 32; // SIMD width
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
int64_t nk0 = ne00;
if (ggml_is_quantized(op->src[0]->type)) {
nk0 = ne00/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne00/ggml_blck_size(op->type);
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk00) {
nrptg = (nth + nk00 - 1)/nk00;
nth = nk00;
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
......@@ -1300,10 +1359,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
}
}
nth = std::min(nth, nk00);
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.ne00 =*/ nk00,
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
......@@ -1321,12 +1381,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
return 1;
}
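// dispatch sketch (reading of the code above, illustrative numbers only): nk0 is the per-row work
// count - ne00/16 for quantized sources, ne00/blck_size for quantized destinations, ne00 otherwise.
// when nrptg == 1, nw0 = ceil(nk0/nth) threadgroups are launched per row along x, e.g. nk0 = 4096
// with nth = 1024 gives nw0 = 4, so the grid width becomes 4*ne01 threadgroups of 1024 threads.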
......@@ -1520,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
......@@ -1875,20 +1936,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
return (ne01 < 20) && (ne00 % 32 == 0);
}
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
}
return res;
}
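// sizing note (hedged reading of the function above): when the KV length ne11 is not a multiple of
// the NCPSG step of the selected kernel, room for NCPSG extra rows of K, V and (if present) mask
// is reserved past the end of dst:
//
//   bytes = NCPSG*(nb11*ne12*ne13 + nb21*ne22*ne23 + (has_mask ? sizeof(half)*ne31*ne32*ne33 : 0))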
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (!has_mask) {
return res;
}
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
return res;
}
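// sizing note (hedged): one int8 flag per (nqptg x ncpsg) tile of the mask is reserved, i.e.
// ceil(ne01/nqptg)*ceil(ne30/ncpsg)*ne32*ne33 bytes, padded to 32. the flags are filled by the
// _blk pre-pass in ggml_metal_op_flash_attn_ext so that the main kernel can presumably skip tiles
// whose mask is entirely -INF; the vector kernels do not use this, hence the early return above.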
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
size_t res = 0;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const int64_t nwg = 32;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
return res;
}
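// sizing note: with the vector kernel and nwg = 32 workgroups per row, each workgroup writes a
// partial result of ne20 floats plus its running softmax statistics S and M, hence the
// ne01*ne02*ne03*nwg*(ne20 + 2) float scratch that kernel_flash_attn_ext_vec_reduce later combines.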
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
......@@ -1911,7 +2059,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
......@@ -1921,8 +2068,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne12 == ne22);
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
float scale;
float max_bias;
......@@ -1949,15 +2096,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne01 < 65536);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_pad = bid_dst;
bid_pad.offs += ggml_nbytes(op);
ggml_metal_buffer_id bid_blk = bid_pad;
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
ggml_metal_buffer_id bid_tmp = bid_blk;
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
ggml_metal_kargs_flash_attn_ext_blk args0 = {
/*.ne01 =*/ ne01,
/*.ne30 =*/ ne30,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
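// note: the pad/blk pre-passes above write into bid_pad/bid_blk, which the main attention kernel
// encoded below reads, so the concurrency range is reset here to order the two dispatches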
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
// 2*(2*ncpsg)
......@@ -2007,6 +2250,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
......@@ -2023,24 +2267,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
......@@ -2048,14 +2286,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg;
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nkpsg = 1*ncpsg;
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
......@@ -2120,6 +2406,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
......@@ -2136,25 +2423,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
const size_t smem = FATTN_SMEM(nsg);
......@@ -2162,23 +2441,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
if (nwg == 1) {
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
// using 1 workgroup -> write the result directly into dst
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
// write the results from each workgroup into a temp buffer
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
......@@ -3156,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
......@@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
// return true if we should use the FA vector kernel for this op
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
......@@ -48,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
......@@ -76,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}
......
......@@ -195,9 +195,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
}
} break;
default:
break;
......@@ -543,6 +543,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->library = GGML_METAL_NAME;
props->caps = {
/* .async = */ true,
......
......@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
}
}
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
if (tiitg != 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}
dst[0] = acc;
}
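// note: deliberately simple serial reduction - the host dispatches a single threadgroup for
// GGML_OP_SUM (see ggml_metal_op_sum) and only thread 0 accumulates, since the op reduces the
// whole tensor to one scalar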
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
......@@ -2032,124 +2050,39 @@ kernel void kernel_ssm_conv_f32_f32(
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
kernel void kernel_ssm_conv_f32_f32_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
const int64_t i0 = tpitg.x;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
s = state;
// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into `sgptg` SIMD groups. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.
// Computed for each thread
float sumf = state * C[i0];
// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);
if (sgptg > 1) {
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
threadgroup_barrier(mem_flags::mem_threadgroup);
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
// For simd group 0 at indices < num simd groups, extract the shared
// simd sum
sumf = 0.0f;
if (sgitg == 0) {
if (tiisg < sgptg) {
sumf = shared[tiisg];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
y[0] = sumf;
}
}
} else if (tiisg == 0) {
y[0] = sumf;
}
float sumf = 0.0f;
// recurse
s0 = s;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
// Assign the final state to the output buffer
s_buff[i] = s;
x[0] = sumf;
}
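For reference, the vectorized conv kernel above accumulates dot(s[i0], c[i0]) over nc/4 float4 packets; the scalar equivalent is a plain dot product over nc elements. A minimal sketch (function name is illustrative, not from ggml):

static float ssm_conv_row(const float * s, const float * c, int nc) {
    // scalar form of the float4 loop above: one multiply-add per state element
    float sum = 0.0f;
    for (int i = 0; i < nc; ++i) {
        sum += s[i] * c[i];
    }
    return sum;
}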
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
kernel void kernel_ssm_scan_group_f32(
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
......@@ -2161,102 +2094,87 @@ kernel void kernel_ssm_scan_group_f32(
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
const int64_t i0 = tpitg.x;
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
shared[tpitg.x] = 0.0f;
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int64_t s_off = args.s_off;
const int32_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
const float state = (s0 * dA) + (B[i0] * x_dt);
s = state;
// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into `sgptg` SIMD groups. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.
// Computed for each thread
float sumf = state * C[i0];
// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);
float s = 0.0f;
// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// For simd group 0 at indices < num simd groups, extract the shared
// simd sum
sumf = 0.0f;
if (sgitg == 0) {
if (tiisg < sgptg) {
sumf = shared[tiisg];
}
sumf = simd_sum(sumf);
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float dt0 = dt[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
const float x_dt = x[0] * dtsp;
const float dA = exp(dtsp * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
y[0] = sumf;
}
shared[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
x += args.ns12;
dt += args.ns21;
B += args.ns42;
C += args.ns52;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
// Assign the final state to the output buffer
s_buff[i] = s;
}
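For reference, the per-timestep arithmetic in both scan kernels reduces to a softplus-clamped dt, an exponential state decay, an input injection, and a dot product of the new state with C. A scalar C++ sketch of one timestep for a single head, following the same arithmetic as the kernel (illustrative only; names and containers are not from ggml):

#include <cmath>
#include <vector>

static float ssm_scan_step(std::vector<float> & state,            // d_state entries, updated in place
                           const std::vector<float> & B,          // d_state
                           const std::vector<float> & C,          // d_state
                           float x, float dt, float A) {
    const float dtsp = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt; // softplus with overflow guard
    const float x_dt = x * dtsp;
    const float dA   = std::exp(dtsp * A);
    float y = 0.0f;
    for (size_t i = 0; i < state.size(); ++i) {
        state[i] = state[i] * dA + B[i] * x_dt;                     // decay + input injection
        y       += state[i] * C[i];                                 // output projection
    }
    return y;
}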
......@@ -4513,10 +4431,142 @@ kernel void kernel_leaky_relu_f32_4(
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
// pad the last chunk of C elements of k and v into an extra pad buffer
kernel void kernel_flash_attn_ext_pad(
constant ggml_metal_kargs_flash_attn_ext_pad & args,
device const char * k,
device const char * v,
device const char * mask,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t C = FC_flash_attn_ext_pad_ncpsg;
device char * k_pad = dst;
device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
const int32_t icp = args.ne11 % C;
const int32_t ic0 = args.ne11 - icp;
const int32_t i1 = tgpig[0];
const int32_t i2 = tgpig[1];
const int32_t i3 = tgpig[2];
if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
if (i1 >= icp) {
// the exact value used here is not important, as we rely on masking out the scores in the attention
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = 0;
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = 0;
}
} else {
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = k_src[i];
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = v_src[i];
}
}
}
if (FC_flash_attn_ext_pad_has_mask) {
if (i2 < args.ne32 && i3 < args.ne33) {
for (int ib = i1; ib < args.ne31; ib += C) {
device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
for (int i = tiitg; i < C; i += ntg.x) {
if (i >= icp) {
mask_dst[i] = -MAXHALF;
} else {
mask_dst[i] = mask_src[i];
}
}
}
}
}
}
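The pad kernel writes into a single scratch allocation laid out as padded K rows, then padded V rows, then a half-precision mask pad, using the same strides as the kernel arguments. A host-side C++ sketch of that layout (struct and function names are hypothetical, and the total size is an assumption inferred from the mask indexing above):

#include <cstdint>

struct kv_pad_layout {
    uint64_t k_off, v_off, mask_off, total;
};

static kv_pad_layout kv_pad_offsets(uint64_t nb11, uint64_t nb21, int C,
                                    int64_t ne_12_2, int64_t ne_12_3,
                                    int64_t ne31, int64_t ne32, int64_t ne33) {
    kv_pad_layout l;
    l.k_off    = 0;
    l.v_off    = l.k_off + nb11*C*ne_12_2*ne_12_3;                 // padded K block (row stride nb11)
    l.mask_off = l.v_off + nb21*C*ne_12_2*ne_12_3;                 // padded V block (row stride nb21)
    l.total    = l.mask_off + sizeof(uint16_t)*C*ne31*ne32*ne33;   // mask pad stored as half (2 bytes)
    return l;
}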
constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
kernel void kernel_flash_attn_ext_blk(
constant ggml_metal_kargs_flash_attn_ext_blk & args,
device const char * mask,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
// block size C x Q
const int32_t Q = FC_flash_attn_ext_blk_nqptg;
const int32_t C = FC_flash_attn_ext_blk_ncpsg;
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig[2]/args.ne32;
const int32_t i2 = tgpig[2]%args.ne32;
const int32_t i1 = tgpig[1];
const int32_t i0 = tgpig[0];
char res = i0*C + C > args.ne30 ? 1 : 0;
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
// fast route
if (res == 0) {
if (simd_max(*mask_src) > -MAXHALF/2) {
res = 1;
}
}
// detailed check of the elements of the block
if ((C > NW || Q > 1) && res == 0) {
half m = -MAXHALF;
FOR_UNROLL (short j = 0; j < Q; ++j) {
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
m = max(m, mask_src[ii*NW]);
}
mask_src += args.nb31/2;
}
if (simd_max(m) > -MAXHALF/2) {
res = 1;
}
}
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
const int32_t nblk0 = ((args.ne30 + C - 1)/C);
if (tiisg == 0) {
dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
}
}
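kernel_flash_attn_ext_blk effectively builds a per-tile skip table: a Q x C tile is flagged 1 when any mask entry exceeds -MAXHALF/2, or when the tile is a partial tile at the edge of the KV range, and 0 otherwise, so the main kernel can skip fully masked tiles. A simplified CPU sketch over a 2D mask (illustration only, no batch dimensions):

#include <cstdint>
#include <vector>

static std::vector<uint8_t> build_blk_table(const std::vector<float> & mask,
                                            int n_q, int n_kv, int Q, int C) {
    const float NEG_MAXHALF = -65504.0f;                 // -MAXHALF for fp16
    const int nblk1 = (n_q  + Q - 1)/Q;
    const int nblk0 = (n_kv + C - 1)/C;
    std::vector<uint8_t> blk(nblk1*nblk0, 0);
    for (int b1 = 0; b1 < nblk1; ++b1) {
        for (int b0 = 0; b0 < nblk0; ++b0) {
            uint8_t res = (b0*C + C > n_kv) ? 1 : 0;     // partial edge tiles are never skipped
            for (int j = b1*Q; j < n_q && j < (b1 + 1)*Q && !res; ++j) {
                for (int i = b0*C; i < n_kv && i < (b0 + 1)*C; ++i) {
                    if (mask[(size_t)j*n_kv + i] > NEG_MAXHALF/2) { res = 1; break; }
                }
            }
            blk[b1*nblk0 + b0] = res;
        }
    }
    return blk;
}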
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
......@@ -4563,6 +4613,8 @@ void kernel_flash_attn_ext_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device const char * blk,
device char * dst,
threadgroup half * shmem_f16,
uint3 tgpig,
......@@ -4628,6 +4680,13 @@ void kernel_flash_attn_ext_impl(
pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
}
{
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
const int32_t nblk0 = ((args.ne11 + C - 1)/C);
blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
}
{
q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
......@@ -4687,16 +4746,75 @@ void kernel_flash_attn_ext_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic = 0; ic < args.ne11; ic += C) {
for (int ic0 = 0; ; ++ic0) {
int ic = ic0*C;
if (ic >= args.ne11) {
break;
}
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_has_mask) {
threadgroup half * sm = (threadgroup half *) (sm2);
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
for (short i = tiisg; i < C; i += NW) {
if (ic + i >= args.ne11) {
sm[2*j*SH + i] = -MAXHALF;
}
}
}
} else {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
pm2[jj] = (device const half2 *) ((device const half *) mask +
(iq1 + j)*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32));
}
}
ic = 0;
}
// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
if (blk[ic0] == 0) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
continue;
}
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
}
pm2[jj] += NW;
}
#if 0
// note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
threadgroup_barrier(mem_flags::mem_threadgroup);
// used to detect blocks full of -INF
......@@ -4715,13 +4833,14 @@ void kernel_flash_attn_ext_impl(
continue;
}
#endif
}
// Q*K^T
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
threadgroup const q_t * pq = sq;
threadgroup s_t * ps = ss;
......@@ -4732,26 +4851,24 @@ void kernel_flash_attn_ext_impl(
constexpr short NC = (C/8)/NSG;
// TODO: not good to unroll for large contexts - not sure why?
// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
if (DK8 % 16 != 0) {
if (DK % 16 != 0) {
k8x8_t mk;
q8x8_t mq;
FOR_UNROLL (short i = 0; i < DK8; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mk, pk, NS10, 0, true);
simdgroup_load(mq, pq, DK);
simdgroup_load(mk, pk + 8*i, NS10, 0, true);
simdgroup_load(mq, pq + 8*i, DK);
simdgroup_barrier(mem_flags::mem_none);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
pk += 8;
pq += 8;
}
} else {
k8x8_t mk[2];
......@@ -4760,26 +4877,22 @@ void kernel_flash_attn_ext_impl(
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
simdgroup_load(mq[0], pq + 0*8, DK);
simdgroup_load(mq[1], pq + 1*8, DK);
simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
simdgroup_barrier(mem_flags::mem_none);
simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
pk += 16;
pq += 16;
}
}
simdgroup_store(mqk, ps, SH, 0, false);
pk += 8*(NSG*NS10 - DK8);
pq += 8*(NSG*0 - DK8);
pk += 8*(NSG*NS10);
ps += 8*(NSG);
}
} else {
......@@ -4793,7 +4906,7 @@ void kernel_flash_attn_ext_impl(
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
for (short ii = 0; ii < DK16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
if (DK16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
......@@ -4913,27 +5026,50 @@ void kernel_flash_attn_ext_impl(
}
{
auto sst = ss;
device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
pv += 8*sgitg;
if (DV <= 64) {
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
s8x8_t vs;
simdgroup_load(vs, sst, SH, 0, false);
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
v8x8_t mv;
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
v8x8_t mv[2];
simdgroup_load(mv, pv, NS20, 0, false);
simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
pv += 8*NSG;
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
}
pv += 8*(NS20 - NO*NSG);
sst += 8;
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
s8x8_t vs[2];
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
v8x8_t mv[4];
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
}
pv += 2*8*NS20;
}
}
}
......@@ -4957,7 +5093,7 @@ void kernel_flash_attn_ext_impl(
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
if (DV16%4 == 0) {
// no need for bound checks
......@@ -5047,7 +5183,7 @@ void kernel_flash_attn_ext_impl(
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
const float scale = 1.0f/S[jj];
const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
if (DV4 % NW == 0) {
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
......@@ -5092,8 +5228,8 @@ template<
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short DK, // K head size
short DV, // V head size
short Q = 8, // queries per threadgroup
short C = 64> // cache items per threadgroup
short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q,
......@@ -5101,13 +5237,15 @@ kernel void kernel_flash_attn_ext(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device const char * blk,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_nsg) {
// note: disabled cases to reduce library load time
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
......@@ -5227,6 +5365,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
......@@ -5253,9 +5392,9 @@ template<
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32, // cache items per threadgroup
short NE, // head elements per thread
short Q, // queries per threadgroup
short C, // cache items per threadgroup
short NSG> // number of simd groups
void kernel_flash_attn_ext_vec_impl(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
......@@ -5264,6 +5403,7 @@ void kernel_flash_attn_ext_vec_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
......@@ -5369,12 +5509,38 @@ void kernel_flash_attn_ext_vec_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
const int ic = ic0 + C*sgitg;
for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
int ic = ic0*C;
if (ic >= args.ne11) {
break;
}
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_vec_has_mask) {
if (ic + tiisg >= args.ne11) {
sm[tiisg] = -MAXHALF;
}
} else {
pm = (device const half *) (mask) +
iq1*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32);
}
ic = 0;
}
if (FC_flash_attn_ext_vec_has_mask) {
sm[tiisg] = pm[ic + tiisg];
}
......@@ -5386,7 +5552,7 @@ void kernel_flash_attn_ext_vec_impl(
// Q*K^T
{
device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
threadgroup const q4_t * pq4 = sq4;
pk4 += ty*NS10/4 + tx;
......@@ -5401,7 +5567,7 @@ void kernel_flash_attn_ext_vec_impl(
mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
}
} else {
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
k4_t mk;
......@@ -5499,7 +5665,7 @@ void kernel_flash_attn_ext_vec_impl(
}
if (is_same<vd4_t, v4_t>::value) {
device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
pv4 += ty*NS20/4 + tx;
......@@ -5512,7 +5678,7 @@ void kernel_flash_attn_ext_vec_impl(
}
} else {
FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
const short i = ii*NL + tx;
......@@ -5637,7 +5803,7 @@ void kernel_flash_attn_ext_vec_impl(
device float4 * dst4 = (device float4 *) dst;
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
// interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
......@@ -5675,8 +5841,8 @@ template<
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
......@@ -5684,13 +5850,14 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_vec_nsg) {
// note: disabled cases to reduce library load time
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
......@@ -5814,7 +5981,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
const float m = simd_max(M);
const float ms = exp(M - m);
S = 1.0f/simd_sum(S*ms);
S = simd_sum(S*ms);
S = S == 0.0f ? 0.0f : 1.0f/S;
const short DV4 = DV/4;
......@@ -5834,21 +6002,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
}
template<typename T0, typename T1>
kernel void kernel_cpy(
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 tptg[[threads_per_threadgroup]]) {
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
......@@ -5859,190 +6023,70 @@ kernel void kernel_cpy(
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
// TODO: templatize these kernels
kernel void kernel_cpy_f32_q8_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q8_0(src, dst_data[i00/QK8_0]);
}
}
kernel void kernel_cpy_f32_q4_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q4_0(src, dst_data[i00/QK4_0]);
}
}
kernel void kernel_cpy_f32_q4_1(
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q4_1(src, dst_data[i00/QK4_1]);
}
}
kernel void kernel_cpy_f32_q5_0(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_q5_0(src, dst_data[i00/QK5_0]);
}
}
quantize_func(src, dst_data[i00]);
kernel void kernel_cpy_f32_q5_1(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_q5_1(src, dst_data[i00/QK5_1]);
break;
}
}
kernel void kernel_cpy_f32_iq4_nl(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
}
}
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
......@@ -6050,11 +6094,12 @@ kernel void kernel_cpy_q_f32(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
......@@ -6066,10 +6111,12 @@ kernel void kernel_cpy_q_f32(
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
......@@ -7522,7 +7569,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
......@@ -7535,13 +7582,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -7552,6 +7598,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK4_NL;
const int ns01 = args.nb01/args.nb00;
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
......@@ -7559,24 +7608,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK4_NL + it * 8;
device const float * yb = y + ix*QK4_NL + it*8;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
float4 qf1, qf2;
for (int ib = ix; ib < nb; ib += 16) {
// [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (short row = 0; row < nr0; row++) {
device const block_iq4_nl & xb = x[row*nb + ib];
for (short row = 0; row < NR0; row++) {
device const block_iq4_nl & xb = x[row*ns01 + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
float4 acc1 = {0.f}, acc2 = {0.f};
......@@ -7607,7 +7657,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -7629,7 +7679,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
......@@ -7642,12 +7692,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -7658,6 +7707,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK_K;
const int ns01 = args.nb01/args.nb00;
const short ix = tiisg/16; // 0 or 1
const short it = tiisg%16; // 0...15
const short ib = it/2;
......@@ -7667,7 +7719,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
......@@ -7676,15 +7728,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
float4 qf1, qf2;
for (int ibl = ix; ibl < nb; ibl += 2) {
// [TAG_MUL_MV_WEIRD]
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (short row = 0; row < nr0; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
for (short row = 0; row < NR0; ++row) {
device const block_iq4_xs & xb = x[row*ns01 + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
float4 acc1 = {0.f}, acc2 = {0.f};
......@@ -7714,7 +7767,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -7736,7 +7789,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
......@@ -7749,13 +7802,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
......@@ -7766,6 +7818,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK_MXFP4;
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
......@@ -7773,20 +7828,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
float sumf[NR0]={0.f};
device const float * yb = y + ix * QK_MXFP4 + it * 8;
device const float * yb = y + ix*QK_MXFP4 + it*8;
// note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
// no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
device const float4 * y4 = (device const float4 *) yb;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
#pragma unroll(nr0)
for (short row = 0; row < nr0; row++) {
device const block_mxfp4 & xb = x[row*nb + ib];
FOR_UNROLL (short row = 0; row < NR0; row++) {
device const block_mxfp4 & xb = x[row*ns01 + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
......@@ -7804,7 +7861,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
......@@ -7831,64 +7888,58 @@ kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
const int32_t i02 = i11;
const int32_t i03 = i12;
for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
template<typename T>
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
}
}
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
kernel void kernel_get_rows_i32(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int32_t i02 = i11;
const int32_t i03 = i12;
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
const int64_t i02 = i11;
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
break;
}
}
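The get_rows kernels all implement the same gather: each output row is the source row selected by an int32 index read from src1, with the quantized variants dequantizing 4x4 blocks on the way out. A simplified 2D C++ reference of that gather (sketch only; the Metal kernels above add the i11/i12 batch dimensions and the work-split index iw0):

#include <cstdint>
#include <vector>

template <typename T>
static std::vector<T> get_rows_ref(const std::vector<T> & src, int ne00,
                                   const std::vector<int32_t> & ids) {
    std::vector<T> dst(ids.size()*ne00);
    for (size_t i = 0; i < ids.size(); ++i) {
        const int32_t r = ids[i];                        // row index from src1
        for (int j = 0; j < ne00; ++j) {
            dst[i*ne00 + j] = src[(size_t)r*ne00 + j];   // copy the selected row
        }
    }
    return dst;
}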
......@@ -8374,12 +8425,13 @@ kernel void kernel_mul_mm_id(
// get rows
//
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
......@@ -8784,3 +8836,51 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
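Written out on the host side, this kernel performs the standard decoupled-weight-decay Adam (AdamW) step. A scalar C++ reference mirroring the arithmetic above (the interpretation of beta1h/beta2h as precomputed bias-correction factors, presumably 1/(1 - beta1^t) and 1/(1 - beta2^t), is an assumption about what pars[5]/pars[6] carry):

#include <cmath>

static void adamw_step_ref(float & x, float g, float & m, float & v,
                           float alpha, float beta1, float beta2, float eps,
                           float wd, float beta1h, float beta2h) {
    m = m*beta1 + g*(1.0f - beta1);              // first-moment EMA
    v = v*beta2 + g*g*(1.0f - beta2);            // second-moment EMA
    const float mh = m*beta1h;                   // bias-corrected moments
    const float vh = std::sqrt(v*beta2h) + eps;
    x = x*(1.0f - alpha*wd) - alpha*mh/vh;       // decoupled weight decay + Adam update
}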
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
......@@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID",
"EXP",
"GELU_ERF",
"XIELU",
};
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
......@@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
}
// ggml_xielu
struct ggml_tensor * ggml_xielu(
struct ggml_context * ctx,
struct ggml_tensor * a,
float alpha_n,
float alpha_p,
float beta,
float eps) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
ggml_set_op_params_f32(result, 3, beta);
ggml_set_op_params_f32(result, 4, eps);
result->op = GGML_OP_UNARY;
result->src[0] = a;
return result;
}
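Note that ggml_xielu stores the alphas already passed through softplus, so backends read effective parameters straight from the op-param slots instead of re-deriving them per element. A minimal sketch of the packing, assuming ggml_softplus is the usual numerically guarded log(1 + exp(x)) (the exact ggml implementation may differ):

#include <cmath>

static float softplus_ref(float x) {
    // stable log(1 + exp(x)) with an overflow guard
    return x > 20.0f ? x : std::log1p(std::exp(x));
}

// op-param slots written by ggml_xielu above:
//   params[1] = beta + softplus(alpha_n)   (effective negative-side alpha)
//   params[2] = softplus(alpha_p)          (effective positive-side alpha)
//   params[3] = beta
//   params[4] = eps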
// ggml_silu_back
struct ggml_tensor * ggml_silu_back(
......@@ -3829,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
struct ggml_tensor * ggml_soft_max_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
float max_bias) {
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
}
void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
......
......@@ -56,7 +56,7 @@ func (c *ImageContext) Free(modelPath string) {
}
}
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]float32, error) {
func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []byte) ([]llama.MtmdChunk, error) {
if c == nil {
return nil, nil
}
......@@ -70,10 +70,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f
c.mu.Lock()
defer c.mu.Unlock()
embed, err := c.findImage(hash)
chunks, err := c.findImage(hash)
if err != nil {
if c.mtmd != nil {
embed, err = c.mtmd.NewEmbed(llamaContext, data)
chunks, err = c.mtmd.MultimodalTokenize(llamaContext, data)
if err != nil {
return nil, err
}
......@@ -81,10 +81,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f
return nil, errors.New("received image but vision model not loaded")
}
c.addImage(hash, embed)
c.addImage(hash, chunks)
}
return embed, nil
return chunks, nil
}
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
......@@ -102,7 +102,7 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
type imageCache struct {
key uint64
val [][]float32
val []llama.MtmdChunk
lastUsed time.Time
}
......@@ -114,7 +114,7 @@ func (c *ImageContext) hashImage(image []byte) uint64 {
var errImageNotFound = errors.New("image not found in cache")
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
func (c *ImageContext) findImage(hash uint64) ([]llama.MtmdChunk, error) {
for i := range c.images {
if c.images[i].key == hash {
slog.Debug("loading image embeddings from cache", "entry", i)
......@@ -126,7 +126,7 @@ func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
return nil, errImageNotFound
}
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
func (c *ImageContext) addImage(hash uint64, embed []llama.MtmdChunk) {
best := time.Now()
var bestImage int
......
......@@ -3,16 +3,18 @@ package llamarunner
import (
"reflect"
"testing"
"github.com/ollama/ollama/llama"
)
func TestImageCache(t *testing.T) {
cache := ImageContext{images: make([]imageCache, 4)}
valA := [][]float32{{0.1, 0.2}, {0.3}}
valB := [][]float32{{0.4}, {0.5}, {0.6}}
valC := [][]float32{{0.7}}
valD := [][]float32{{0.8}}
valE := [][]float32{{0.9}}
valA := []llama.MtmdChunk{{Embed: []float32{0.1, 0.2}}, {Embed: []float32{0.3}}}
valB := []llama.MtmdChunk{{Embed: []float32{0.4}}, {Embed: []float32{0.5}}, {Embed: []float32{0.6}}}
valC := []llama.MtmdChunk{{Embed: []float32{0.7}}}
valD := []llama.MtmdChunk{{Embed: []float32{0.8}}}
valE := []llama.MtmdChunk{{Embed: []float32{0.9}}}
// Empty cache
result, err := cache.findImage(0x5adb61d31933a946)
......
......@@ -209,13 +209,19 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
return nil, fmt.Errorf("invalid image index: %d", n)
}
embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data)
chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data)
if err != nil {
return nil, err
}
for _, e := range embed {
inputs = append(inputs, input{embed: e})
for _, c := range chunks {
if len(c.Embed) != 0 {
inputs = append(inputs, input{embed: c.Embed})
} else {
for _, t := range c.Tokens {
inputs = append(inputs, input{token: t})
}
}
}
}
}
......