OpenDAS / ollama · Commit 3a9e8e9f (unverified)
Authored by Daniel Hiltgen on Nov 12, 2025; committed via GitHub on Nov 12, 2025.
Parent: cb1cb064

vulkan: temporary carry of vulkan fixes (#12971)

This should be reverted once we update ggml past b6897.
Changes (32) — showing 20 changed files with 4958 additions and 257 deletions.
llama/patches/0029-vulkan-Call-ggml_vk_buffer_write_2d-from-ggml_vk_buf.patch  +32 -0
llama/patches/0030-Vulkan-MMQ-Integer-Dot-Refactor-and-K-Quant-support-.patch  +2140 -0
llama/patches/0031-vulkan-Update-topk_moe-fusion-to-handle-gpt-s-late-s.patch  +657 -0
llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch  +1242 -0
llama/patches/0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch  +85 -0
llama/patches/0034-vulkan-fix-shmem-overrun-in-mmq-id-shader-16873.patch  +77 -0
llama/patches/0035-vulkan-Fix-crash-when-FP16-mul_mat-accumulation-is-n.patch  +80 -0
ml/backend/ggml/ggml/src/ggml-impl.h  +16 -0
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp  +1 -1
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp  +594 -224
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp  +12 -4
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl  +5 -5
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl  +3 -3
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp  +2 -2
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp  +2 -2
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp  +2 -2
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp  +2 -2
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp  +2 -4
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp  +2 -4
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp  +2 -4
llama/patches/0029-vulkan-Call-ggml_vk_buffer_write_2d-from-ggml_vk_buf.patch (new file, mode 100644)
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
(#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}
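The intent of the one-line change above is easy to miss in diff form. Here is a minimal sketch of the device-to-device path after this patch, reusing the vk_buffer type and helper names visible in the diff (an illustration of the control flow, not the verbatim ggml implementation):

// Old path: stage on the source device, memcpy into a second staging buffer on
// the destination device, then copy that staging buffer into dst.
// New path: stage on the source device only, then write straight into dst.
static void copy_device_to_device(vk_buffer& dst, size_t dst_offset,
                                  vk_buffer& src, size_t src_offset, size_t size) {
    // Stage the source data once, on the source device.
    ggml_vk_ensure_sync_staging_buffer(src->device, size);
    ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
    // ggml_vk_buffer_write_2d uploads from host memory, so on devices that
    // expose host-visible vidmem the data lands in dst without a second
    // staging buffer or the intermediate memcpy.
    ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}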
llama/patches/0030-Vulkan-MMQ-Integer-Dot-Refactor-and-K-Quant-support-.patch (new file, mode 100644)
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Wed, 29 Oct 2025 14:39:03 +0100
Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
* vulkan: add mmq q2_k integer dot support
* Refactor mmq caching
* Reduce mmq register use
* Load 4 quant blocks into shared memory in one step
* Pack q2_k blocks into caches of 32
* Use 32-bit accumulators for integer dot matmul
* Add q4_k mmq
* Add q3_k mmq
* Add q5_k mmq
* Add q6_k mmq
* Add mxfp4 mmq, enable MMQ MUL_MAT_ID
* Fix mmv dm loads
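The common thread in the items above: the MMQ shaders now do the whole inner product in integer arithmetic and only touch floats once per block. A rough C++ sketch of that idea (dot_packed_4x8 stands in for the dotPacked4x8EXT instruction the GLSL shaders use; the helper names and signatures here are hypothetical):

#include <cstdint>

// Stand-in for GLSL dotPacked4x8EXT: dot product of four packed int8 pairs.
static int32_t dot_packed_4x8(uint32_t a, uint32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        sum += int32_t(int8_t(a >> (8 * i))) * int32_t(int8_t(b >> (8 * i)));
    }
    return sum;
}

// One block's contribution: integer accumulation first, one float fixup last.
// a_scale/a_min mirror the weight block's dm pair, b_scale/b_sum the Q8_1 ds pair.
static float mmq_block_dot(const uint32_t *a_qs, const uint32_t *b_qs, int n_packed,
                           float a_scale, float a_min, float b_scale, float b_sum) {
    int32_t q_sum = 0;  // 32-bit accumulator, as introduced by this patch
    for (int k = 0; k < n_packed; ++k) {
        q_sum += dot_packed_4x8(a_qs[k], b_qs[k]);
    }
    // K-quant style fixup: scale the integer sum, subtract the block-minimum term.
    return a_scale * b_scale * float(q_sum) - a_min * b_sum;
}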
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 165 +++++-
.../vulkan-shaders/dequant_funcs.glsl | 10 +-
.../vulkan-shaders/dequant_funcs_cm2.glsl | 6 +-
.../vulkan-shaders/dequant_mxfp4.comp | 4 +-
.../vulkan-shaders/dequant_q2_k.comp | 4 +-
.../vulkan-shaders/dequant_q4_k.comp | 4 +-
.../vulkan-shaders/dequant_q5_k.comp | 4 +-
.../vulkan-shaders/mul_mat_vec_q2_k.comp | 6 +-
.../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 +-
.../vulkan-shaders/mul_mat_vec_q5_k.comp | 6 +-
.../ggml-vulkan/vulkan-shaders/mul_mm.comp | 72 +--
.../vulkan-shaders/mul_mm_funcs.glsl | 14 +-
.../vulkan-shaders/mul_mm_id_funcs.glsl | 70 +++
.../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 288 +++-------
.../vulkan-shaders/mul_mmq_funcs.glsl | 538 ++++++++++++++++--
.../vulkan-shaders/mul_mmq_shmem_types.glsl | 78 +++
.../src/ggml-vulkan/vulkan-shaders/types.glsl | 53 +-
.../vulkan-shaders/vulkan-shaders-gen.cpp | 5 +-
18 files changed, 928 insertions(+), 405 deletions(-)
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 18b7cbccf..53b57c179 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -488,6 +488,7 @@
struct vk_device_struct {
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
@@ -2449,8 +2450,11 @@
static void ggml_vk_load_shaders(vk_device& device) {
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
+ l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
- l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
+ l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
+ l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
@@ -2513,10 +2517,16 @@
static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+ // Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+ // K-quants use even more registers, mitigate by setting WMITER to 1
+ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
+ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
+
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
@@ -2525,10 +2535,18 @@
static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+ s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+
+ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
- m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2913,18 +2931,15 @@
static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
@@ -2963,11 +2978,19 @@
static void ggml_vk_load_shaders(vk_device& device) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
}
#endif
@@ -2997,6 +3020,24 @@
static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ }
+#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
@@ -3023,6 +3064,24 @@
static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
+#endif
}
#undef CREATE_MM2
#undef CREATE_MMQ
@@ -3087,6 +3146,12 @@
static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif
@@ -3146,7 +3211,7 @@
static void ggml_vk_load_shaders(vk_device& device) {
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
&& !device->coopmat_bf16_support
#endif
) {
@@ -4930,7 +4995,7 @@
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
- vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
@@ -5077,6 +5142,17 @@
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
+ // MMQ
+ if (src1_type == GGML_TYPE_Q8_1) {
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
+
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ return nullptr;
+ }
+
+ return pipelines;
+ }
+
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
@@ -6879,10 +6955,19 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+ // Check for mmq first
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+ if (mmp == nullptr) {
+ // Fall back to f16 dequant mul mat
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ quantize_y = false;
+ }
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
- const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
@@ -6892,8 +6977,8 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
@@ -6906,12 +6991,13 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
+ vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
@@ -6926,9 +7012,16 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+ if (quantize_y) {
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+ }
+
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+ }
if (
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
@@ -6937,7 +7030,7 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
@@ -6949,6 +7042,9 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
+ if (quantize_y) {
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+ }
return;
}
@@ -6985,6 +7081,9 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+ } else if (quantize_y) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -7016,6 +7115,17 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
}
}
+ if (quantize_y) {
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
+ }
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@@ -7024,14 +7134,19 @@
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
+ uint32_t y_sz_total = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+ }
+
// compute
ggml_vk_matmul_id(
ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
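Several of the hunks above thread a new quantize_y flag through ggml_vk_mul_mat_id_q_f16. A hedged, stand-alone C++ sketch of the decision those hunks implement (stub types and a hypothetical get_pipeline callback, not the real ggml signatures): quantize the activations to Q8_1 and take the MMQ pipeline when the device supports integer dot products and the f32 activations are contiguous with a row size divisible by 4; otherwise fall back to the existing dequant + f16 matmul path.

#include <cstdint>
#include <functional>

enum ggml_type { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_1 };  // stub subset
struct vk_matmul_pipeline { /* opaque */ };

// Returns the pipeline to use and sets quantize_y for the caller.
static vk_matmul_pipeline *pick_mul_mat_id_pipeline(
        bool integer_dot, bool src1_contig, bool y_non_contig,
        int64_t ne10, int64_t ne11, ggml_type src0_type, ggml_type src1_type,
        const std::function<vk_matmul_pipeline *(ggml_type, ggml_type)> &get_pipeline,
        bool &quantize_y) {
    quantize_y = integer_dot && src1_type == GGML_TYPE_F32 && src1_contig &&
                 (ne11 * ne10) % 4 == 0;
    // Check for MMQ first.
    vk_matmul_pipeline *mmp = quantize_y ? get_pipeline(src0_type, GGML_TYPE_Q8_1) : nullptr;
    if (mmp == nullptr) {
        // Fall back to dequant + f16 mul_mat, matching the original behaviour.
        mmp = get_pipeline(src0_type, y_non_contig ? GGML_TYPE_F16 : src1_type);
        quantize_y = false;
    }
    return mmp;
}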
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
index 0d98f5a9d..09676a623 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -437,7 +437,7 @@
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
- const vec2 d = vec2(data_a[a_offset + ib].d);
+ const vec2 dm = vec2(data_a[a_offset + ib].dm);
- return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
@@ -529,7 +529,7 @@
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
index 67baedf7c..8ac6482dc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
@@ -120,7 +120,7 @@
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
- const f16vec2 d = bl.block.d;
+ const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
- float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+ float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
@@ -680,7 +680,7 @@
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
- float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
index ffba5a77d..3194ba291 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
@@ -26,7 +26,7 @@
void main() {
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
- data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+ data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+ data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
index 58dc2e5df..dc05a7834 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
@@ -24,8 +24,8 @@
void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
index 8b7be557e..0f23dc0a3 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
@@ -20,8 +20,8 @@
void main() {
const uint is = 2 * il;
const uint n = 4;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
index 6bc04670f..970469a60 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
@@ -19,8 +19,8 @@
void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 03ed25d3b..14093c0de 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -41,9 +41,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
- temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+ temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index 21d07d2e5..49d91ad59 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -14,9 +14,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 9e46c89a1..0d61b4966 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -14,9 +14,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index a20788c4b..d260969f0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -120,81 +120,11 @@
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[BN];
-uint _ne1;
-
-#ifdef MUL_MAT_ID_USE_SUBGROUPS
-shared uvec4 ballots_sh[NUM_WARPS];
-
-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
- uint nei0shift = findLSB(p.nei0);
-
- uint ids[16];
- uint iter = 0;
-
- for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
- // prefetch up to 16 elements
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_LocalInvocationIndex;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-
- ballots_sh[gl_SubgroupID] = ballot;
- barrier();
-
- uint subgroup_base = 0;
- uint total = 0;
- for (uint k = 0; k < gl_NumSubgroups; ++k) {
- if (k == gl_SubgroupID) {
- subgroup_base = total;
- }
- total += subgroupBallotBitCount(ballots_sh[k]);
- }
- barrier();
-
- uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
- row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
- }
- _ne1 += total;
- iter &= 15;
- if (_ne1 >= (ic + 1) * BN) {
- break;
- }
- }
- barrier();
-}
-#endif // MUL_MAT_ID_USE_SUBGROUPS
-#endif // MUL_MAT_ID
-
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
+#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
index 0ebfbd646..ee5ded2e8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
@@ -134,15 +134,15 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+ const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+ const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
- const vec2 d = vec2(data_a[ib].d);
+ const vec2 dm = vec2(data_a[ib].dm);
- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
- const float d = e8m0_to_fp32(data_a[ib].e);
+ const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
new file mode 100644
index 000000000..1d0e84ac9
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -0,0 +1,70 @@
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[BN];
+uint _ne1;
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+shared uvec4 ballots_sh[NUM_WARPS];
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+ uint nei0shift = findLSB(p.nei0);
+
+ uint ids[16];
+ uint iter = 0;
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+ ballots_sh[gl_SubgroupID] = ballot;
+ barrier();
+
+ uint subgroup_base = 0;
+ uint total = 0;
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier();
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
+ }
+ _ne1 += total;
+ iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
+ }
+ barrier();
+}
+#endif // MUL_MAT_ID_USE_SUBGROUPS
+#endif // MUL_MAT_ID
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index b5d761c0b..8b238ac4b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -76,40 +78,27 @@
layout (constant_id = 10) const uint WARP = 32;
#define BK 32
-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
-#endif
+#define MMQ_SHMEM
-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#include "mul_mmq_shmem_types.glsl"
-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
+#ifndef BK_STEP
+#define BK_STEP 4
#endif
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
-#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
-#endif // MUL_MAT_ID
-
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
-
+#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
@@ -139,26 +128,12 @@
void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
-
-#ifdef COOPMAT
- const uint warp_i = gl_SubgroupID;
-
- const uint tiw = gl_SubgroupInvocationID;
-
- const uint cms_per_row = WM / TM;
- const uint cms_per_col = WN / TN;
-
- const uint storestride = WARP / TM;
- const uint store_r = tiw % TM;
- const uint store_c = tiw / TM;
-#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
-#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +147,27 @@
void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
- uint _ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true, ic);
+ } else {
+ load_row_ids(expert_idx, false, ic);
+ }
+#else
+ _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
_ne1++;
}
}
}
barrier();
+#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -209,159 +194,70 @@
void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
-#ifdef COOPMAT
- coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
- coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
- coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
-
- [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
- sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
- }
-#else
- int32_t cache_a_qs[WMITER * TM * BK / 4];
-
- int32_t cache_b_qs[TN * BK / 4];
-
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
-#endif
-#if QUANT_AUXF == 1
- FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
- FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
- FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
- for (uint block = start_k; block < end_k; block += BK) {
+ for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
- const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
- const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
+ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+ const uint iqs = loadr_a;
- if (iqs == 0) {
-#if QUANT_AUXF == 1
- buf_a_dm[buf_ib] = get_d(ib);
-#else
- buf_a_dm[buf_ib] = get_dm(ib);
-#endif
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
-#if QUANT_R == 1
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
- const i32vec2 vals = repack(ib, iqs);
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+ const uint buf_ib = loadc_b + l;
+
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
- const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
- const uint ib = idx / 8;
- const uint iqs = idx & 0x7;
+ const u16vec2 row_idx = row_ids[buf_ib];
+ const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
- const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
- const uint ib_outer = ib / 4;
- const uint ib_inner = ib % 4;
-
- const uint iqs = loadr_b;
+ const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
+ const uint iqs = loadr_b;
- const uint buf_ib = loadc_b + l;
-
- if (iqs == 0) {
- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
- const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
- pos_a_ib += 1;
- pos_b_ib += 1;
+ pos_a_ib += BK_STEP;
+ pos_b_ib += BK_STEP;
-#ifdef COOPMAT
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- const uint ib_a = warp_r * WM + cm_row * TM;
+ for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
- coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
- cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
- }
-
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const uint ib_b = warp_c * WN + cm_col * TN;
- coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
- cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
- }
-
- cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
- cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
- }
-
- coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
- sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
- }
- }
-#else
- // Load from shared into cache
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
- }
- }
- }
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint reg_ib = wsir * TM + cr;
+ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
- cache_b_ds[cc] = buf_b_ds[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+ block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint cache_a_idx = wsir * TM + cr;
- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
- int32_t q_sum = 0;
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
- cache_b_qs[cc * (BK / 4) + idx_k]);
- }
+ const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+ block_b_to_registers(ib);
- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint cache_a_idx = wsir * TM + cr;
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+
+ sums[sums_idx] += mmq_dot_product(cache_a_idx);
+ }
}
}
}
}
-#endif
barrier();
}
@@ -373,54 +269,6 @@
void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < BN; col += storestride) {
- const uint row_i = dc + cm_col * TN + col + store_c;
- if (row_i >= _ne1) break;
-
- const u16vec2 row_idx = row_ids[row_i];
-
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
-#else
- const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
-
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
- if (is_aligned && is_in_bounds) {
- // Full coopMat is within bounds and stride_d is aligned with 16B
- coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
- coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
- } else if (is_in_bounds) {
- // Full coopMat is within bounds, but stride_d is not aligned
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
- // Partial coopMat is within bounds
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
- }
- }
-#endif // MUL_MAT_ID
-#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -431,19 +279,21 @@
void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ if (dr_warp + cr < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+ }
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
-#endif // COOPMAT
}
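The hunk above replaces the old per-block load/compute path (and the COOPMAT branch) with a BK_STEP-deep pipeline: all k_step tiles are first staged into shared memory via block_a_to_shmem/block_b_to_shmem, then each k_step is copied into registers and accumulated with mmq_dot_product. A single-threaded C++ rendering of that control flow follows, for orientation only; the float stand-ins and the tiny BM/BN/BK_STEP sizes are assumptions chosen so the sketch compiles and runs, not the shader's real quantized types or tile sizes.

// Single-threaded C++ rendering of the restructured MMQ tile loop above.
#include <array>
#include <cstdio>

constexpr int BM = 4, BN = 2, BK_STEP = 2, K_BLOCKS = 4;

int main() {
    // Global operands, one scalar per (row, k-block) / (col, k-block).
    std::array<float, BM * K_BLOCKS> A{};
    std::array<float, BN * K_BLOCKS> B{};
    for (size_t i = 0; i < A.size(); ++i) A[i] = 1.0f + i;
    for (size_t i = 0; i < B.size(); ++i) B[i] = 0.5f * (i + 1);

    std::array<float, BM * BK_STEP> buf_a{};   // "shared memory": BK_STEP tiles of BM rows
    std::array<float, BN * BK_STEP> buf_b{};   // "shared memory": BK_STEP tiles of BN cols
    std::array<float, BM * BN>      sums{};    // accumulators

    for (int pos = 0; pos < K_BLOCKS; pos += BK_STEP) {
        // 1) stage BK_STEP blocks of A and B into shared memory up front
        //    (block_a_to_shmem / block_b_to_shmem in the shader, then barrier())
        for (int k = 0; k < BK_STEP; ++k) {
            for (int r = 0; r < BM; ++r) buf_a[k * BM + r] = A[r * K_BLOCKS + pos + k];
            for (int c = 0; c < BN; ++c) buf_b[k * BN + c] = B[c * K_BLOCKS + pos + k];
        }

        // 2) per k_step: copy the A tile into registers, then accumulate per B column
        //    (block_a_to_registers / block_b_to_registers / mmq_dot_product)
        for (int k = 0; k < BK_STEP; ++k) {
            std::array<float, BM> cache_a{};
            for (int r = 0; r < BM; ++r) cache_a[r] = buf_a[k * BM + r];
            for (int c = 0; c < BN; ++c) {
                const float cache_b = buf_b[k * BN + c];
                for (int r = 0; r < BM; ++r) sums[c * BM + r] += cache_a[r] * cache_b;
            }
        }
        // barrier() again before the next BK_STEP staging pass
    }

    std::printf("sums[0] = %g\n", sums[0]);
    return 0;
}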
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index fe71eb131..c0c03fedc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+#ifdef DATA_A_Q4_0
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
+ return i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+#endif
}
+#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q4_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- return i32vec2( vui & 0x0F0F0F0F,
- (vui >> 4) & 0x0F0F0F0F);
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+#else // DATA_A_Q4_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
-#endif
-#if defined(DATA_A_Q5_0)
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
- const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+ const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,40 +98,457 @@
i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}
+#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q5_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
- const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
- | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
- const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
- | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+ }
+#else // DATA_A_Q5_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
- return i32vec2(v0, v1);
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+ const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+ const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
- return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]));
+ return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+ const int32_t qs_b = cache_b.qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, qs_b);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+ const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ return i32vec2( quants & 0x0F0F0F0F,
+ (quants >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(da * dsb.x * float(q_sum));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
+ const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+ buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
+ buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d = buf_a[buf_ib].d;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+int32_t repack(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ return data_a[ib_k].scales[iqs_k / 4];
+}
+
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ // Repack 4x4 quants into one int
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+ const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+ const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+ buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t sum_d = 0;
+ int32_t sum_m = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+ const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+ sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+ sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint hm_idx = iqs * QUANT_R_MMQ;
+ const uint iqs_k = (ib % 8) * 8 + hm_idx;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+ const uint hm_shift = iqs_k / 8;
+
+ // Repack 2x4 quants into one int
+ // Add the 3rd bit instead of subtracting it to allow packing the quants
+ const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+ buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
+ (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
+ (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ // Subtract 4 from the quants to correct the 3rd bit offset
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+ // Repack 2x4 quants into one int
+#if defined(DATA_A_Q4_K)
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
+#else // defined(DATA_A_Q5_K)
+ const uint qh_idx = iqs * QUANT_R_MMQ;
+ const uint qh_shift = iqs_k / 8;
+
+ buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
+ (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
+#endif
+
+
+ if (iqs == 0) {
+ // Scale index
+ const uint is = iqs_k / 8;
+ u8vec2 scale_dm;
+ if (is < 4) {
+ scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+ } else {
+ scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+ (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
+ }
+
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+#if defined(DATA_A_Q4_K)
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
+#else // defined(DATA_A_Q5_K)
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+#endif
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#ifdef MMQ_SHMEM
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_outer = ib / 4;
+ const uint ib_inner = ib % 4;
+
+ if (iqs == 0) {
+ buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ }
+
+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+ buf_b[buf_ib].qs[iqs * 4 ] = values.x;
+ buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+ buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+ buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+ cache_b.ds = buf_b[ib].ds;
+ [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+ }
+}
+#endif
+
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+ const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+ const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+ const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+ const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+ const uint ib_k = ib / 8;
+ return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
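The mul_q8_1/mmq_dot_product pairs added above all follow the same recipe: accumulate the quantized dot product in int32 via dotPacked4x8EXT, then rescale with the per-block scales. A scalar C++ sketch of the simplest pairing (8-bit weights against an 8-bit activation block) follows; block size 32 matches QUANT_K_Q8_0, while the quantize helper is an illustrative stand-in rather than ggml's implementation.

// Scalar sketch of the Q8_0 path: int32 integer dot, then rescale by d_a * d_b.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK = 32;

struct BlockQ8 {
    float  d;        // per-block scale
    int8_t qs[QK];   // quantized values
};

static BlockQ8 quantize(const float * x) {
    BlockQ8 b{};
    float amax = 0.0f;
    for (int i = 0; i < QK; ++i) amax = std::fmax(amax, std::fabs(x[i]));
    b.d = amax / 127.0f;
    const float id = b.d != 0.0f ? 1.0f / b.d : 0.0f;
    for (int i = 0; i < QK; ++i) b.qs[i] = (int8_t)std::lround(x[i] * id);
    return b;
}

int main() {
    float a[QK], b[QK];
    for (int i = 0; i < QK; ++i) { a[i] = 0.01f * i; b[i] = 1.0f - 0.02f * i; }

    const BlockQ8 qa = quantize(a), qb = quantize(b);

    int32_t q_sum = 0;   // what dotPacked4x8EXT accumulates four lanes at a time
    for (int i = 0; i < QK; ++i) q_sum += (int32_t)qa.qs[i] * qb.qs[i];

    const float approx = (float)q_sum * qa.d * qb.d;   // mul_q8_1 for Q8_0

    float exact = 0.0f;
    for (int i = 0; i < QK; ++i) exact += a[i] * b[i];
    std::printf("exact %.4f  mmq %.4f\n", exact, approx);
    return 0;
}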
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
new file mode 100644
index 000000000..72fec4404
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -0,0 +1,78 @@
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+// AMD likes 4, Intel likes 1 and Nvidia likes 2
+#define BK_STEP 1
+struct block_a_cache {
+ int32_t qs[32/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_MXFP4)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE d;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache {
+ uint32_t qs[2];
+ u8vec2 scales;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q3_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#elif defined(DATA_A_Q4_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q6_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#endif
+
+struct block_b_cache
+{
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 ds;
+};
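QUANT_R_MMQ in these cache structs is the packing ratio: how many sub-byte quant values share each byte lane of qs[] (2 for the 4-bit and 5-bit formats, 4 for Q2_K, 1 for 8-bit data). A small self-contained C++ illustration of the 4-bit case follows, using the same shift-and-mask recovery as the Q4_K mmq_dot_product above; the value set is arbitrary.

// With QUANT_R_MMQ == 2, two 4-bit values share each byte, recovered with
// (qs >> ((iqs % 2) * 4)) & 0x0F0F0F0F.
#include <cstdint>
#include <cstdio>

int main() {
    // eight 4-bit values: low nibbles hold v[0..3], high nibbles hold v[4..7]
    const uint8_t v[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };

    uint32_t packed = 0;
    for (int i = 0; i < 4; ++i) {
        packed |= (uint32_t)(v[i]     & 0xF) << (8 * i);       // low nibble of byte i
        packed |= (uint32_t)(v[i + 4] & 0xF) << (8 * i + 4);   // high nibble of byte i
    }

    for (int iqs = 0; iqs < 2; ++iqs) {
        const uint32_t lane = (packed >> (iqs * 4)) & 0x0F0F0F0Fu;
        for (int b = 0; b < 4; ++b) {
            const unsigned q = (lane >> (8 * b)) & 0xF;
            std::printf("iqs=%d byte=%d -> %u\n", iqs, b, q);
        }
    }
    return 0;
}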
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
index 2fa54ce51..02578c77c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
@@ -66,6 +66,7 @@
struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@
struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@
struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@
struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@
struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@
struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
- f16vec2 d;
+ f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@
struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
+#define SCALES_PER_32 2
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@
struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
- f16vec2 d;
+ f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
@@ -310,13 +318,14 @@
struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@
struct block_q5_K
struct block_q5_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};
+struct block_q5_K_packed32
+{
+ f16vec2 dm;
+ uint32_t scales[12/4];
+ uint32_t qh[QUANT_K_Q5_K/8/4];
+ uint32_t qs[QUANT_K_Q5_K/2/4];
+};
+
struct block_q5_K_packed128
{
uvec4 q5k[11];
@@ -340,6 +357,8 @@
struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
+#define A_TYPE_PACKED32 block_q5_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@
struct block_q6_K_packed16
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
- int8_t scales[QUANT_K_Q6_K/16];
+ int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};
@@ -365,6 +384,7 @@
struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
+#define DATA_A_QUANT_K
#endif
// IQuants
@@ -1363,18 +1383,11 @@
struct block_mxfp4
uint8_t qs[QUANT_K_MXFP4/2];
};
-//struct block_mxfp4_packed16
-//{
-// uint8_t e;
-// uint16_t qs[QUANT_K_MXFP4/2/2];
-//};
-
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
-//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@
void init_iq_shmem(uvec3 wgsize)
#endif
#if defined(DATA_A_MXFP4)
-const FLOAT_TYPE kvalues_mxfp4_const[16] = {
- FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
- FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+const int8_t kvalues_mxfp4_const[16] = {
+ int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
+ int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};
-shared FLOAT_TYPE kvalues_mxfp4[16];
+shared int8_t kvalues_mxfp4[16];
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 0f25ba345..03fa01639 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -566,7 +566,8 @@
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
- if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
+ // Integer dot mmq performs better with f32 accumulators
+ if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -574,7 +575,7 @@
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
llama/patches/0031-vulkan-Update-topk_moe-fusion-to-handle-gpt-s-late-s.patch
0 → 100644
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
(#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax
Based on #16649.
* Add ggml_check_edges
* Add sync logging to show fusion effects
* handle clamp added in #16655
* Update ggml/src/ggml-impl.h
Co-authored-by: Diego Devesa <slarengh@gmail.com>
---
ggml/src/ggml-impl.h | 16 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++-------
.../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++--
3 files changed, 272 insertions(+), 138 deletions(-)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@
GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
+#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
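ggml_check_edges walks a list of {dst, src_idx, src} triples and confirms that node start_idx+dst reads its src_idx-th input from node start_idx+src; the Vulkan backend later in this patch pairs it with ggml_can_fuse_subgraph before fusing the topk_moe patterns. A toy, self-contained C++ model of that check follows; Node and check_edges are stand-ins for illustration, not ggml's types, but the triple semantics are the same.

// Toy model of the edge check used for fusion pattern matching.
#include <array>
#include <cstdio>
#include <initializer_list>
#include <vector>

struct Node { std::array<const Node *, 2> src{}; };

static bool check_edges(const std::vector<Node *> & nodes, int start,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & e : edges) {
        if (nodes[start + e[0]]->src[e[1]] != nodes[start + e[2]]) {
            return false;
        }
    }
    return true;
}

int main() {
    // softmax feeding both a reshape and an argsort, matching the first two
    // triples of topk_moe_early_softmax_edges ({1,0,0} and {2,0,0}).
    Node softmax, reshape, argsort;
    reshape.src[0] = &softmax;
    argsort.src[0] = &softmax;
    std::vector<Node *> nodes = { &softmax, &reshape, &argsort };

    std::printf("edges ok: %d\n", check_edges(nodes, 0, {{1, 0, 0}, {2, 0, 0}}));
    return 0;
}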
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+ GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+ { 5, 0, 4 }, // reshape->src[0] == get_rows
+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
+ { 7, 0, 6 }, // clamp->src[0] == sum_rows
+ { 8, 0, 5 }, // div->src[0] == reshape
+ { 8, 1, 7 }, // div->src[1] == clamp
+ { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+};
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+ { 1, 0, 0 }, // view->src[0] == argsort
+ { 2, 1, 1 }, // get_rows->src[1] == view
+ { 3, 0, 2 }, // reshape->src[0] == get_rows
+ { 4, 0, 3 }, // soft_max->src[0] == reshape
+ { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+ TOPK_MOE_EARLY_SOFTMAX,
+ TOPK_MOE_EARLY_SOFTMAX_NORM,
+ TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
+ TOPK_MOE_LATE_SOFTMAX;
+ return mode;
+}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -607,8 +671,7 @@
struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
- // [2] is {!norm, norm}
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -956,6 +1019,8 @@
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@
static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
- return ctx->device->pipeline_topk_moe[idx][with_norm];
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
+ if (ctx->num_additional_fused_ops) {
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
+ }
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
- ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
+ cgraph->nodes[node_idx + 5];
+ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
- vk_op_topk_moe_push_constants pc;
+ vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
+ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11335,7 +11415,13 @@
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+
+#define ENABLE_SYNC_LOGGING 0
+
if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+ std::cerr << "sync" << std::endl;
+#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+#if ENABLE_SYNC_LOGGING
+ if (!dryrun) {
+ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ auto *n = cgraph->nodes[node_idx + i];
+ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+ if (n->op == GGML_OP_GLU) {
+ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+ }
+ std::cerr << std::endl;
+ }
+ }
+#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ if (ctx->num_additional_fused_ops) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+ } else {
+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ }
break;
case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
- int node_idx, bool with_norm) {
+ int node_idx, topk_moe_mode mode) {
- if (with_norm) {
- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
- return false;
- }
- }
- } else {
- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
- return false;
- }
- }
- }
+ const ggml_tensor * softmax;
+ const ggml_tensor * weights;
- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
- const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+ switch (mode) {
+ case TOPK_MOE_EARLY_SOFTMAX_NORM:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 9];
+ break;
+ case TOPK_MOE_EARLY_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 4];
+ break;
+ case TOPK_MOE_LATE_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 4];
+ weights = cgraph->nodes[node_idx + 5];
+ break;
+ default:
+ return false;
+ }
const float * op_params = (const float *)softmax->op_params;
@@ -12378,60 +12477,6 @@
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
- // Check that the nodes don't have any unexpected uses
- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
- // softmax is used by reshape and argsort
- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
- reshape1->src[0] != softmax ||
- argsort->src[0] != softmax) {
- return false;
- }
- // reshape is used by get_rows
- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
- get_rows->src[0] != reshape1) {
- return false;
- }
- // argsort is used by view
- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
- view->src[0] != argsort) {
- return false;
- }
- // view is written (via argsort), we can skip checking it
-
- if (with_norm) {
- // get_rows is used by reshape
- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
- reshape5->src[0] != get_rows) {
- return false;
- }
-
- // reshape is used by sum_rows and div
- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
- sum_rows->src[0] != reshape5 ||
- div->src[0] != reshape5) {
- return false;
- }
-
- // sum_rows is used by div
- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
- div->src[1] != sum_rows) {
- return false;
- }
-
- // div/reshape are written
- if (reshape8->src[0] != div) {
- return false;
- }
- }
-
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12754,25 +12815,44 @@
static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
- // Avoid reordering topk_moe_norm
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
- bool is_topk_moe_norm = true;
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
- is_topk_moe_norm = false;
+ // Check for fusion patterns and avoid reordering them
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+ if (start + (int)pattern.size() <= graph->n_nodes) {
+ bool is_pattern = true;
+ for (size_t j = 0; j < pattern.size(); ++j) {
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+ is_pattern = false;
+ }
}
+ return is_pattern;
}
- if (is_topk_moe_norm) {
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+ return false;
+ };
+
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+ if (match_pattern(pattern, first_unused)) {
+ for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
- continue;
+ return true;
}
+ return false;
+ };
+
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
+ continue;
+ }
+ if (keep_pattern(topk_moe_early_softmax)) {
+ continue;
}
+ if (keep_pattern(topk_moe_late_softmax)) {
+ continue;
+ }
+
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12790,6 +12870,12 @@
static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
+ // Don't pull forward nodes from fusion patterns
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_early_softmax, j) ||
+ match_pattern(topk_moe_late_softmax, j)) {
+ continue;
+ }
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,53 +28,72 @@
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
-void main() {
- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
- if (row >= n_rows) {
- return;
- }
+const float INFINITY = 1.0 / 0.0;
- const uint logits_offset = n_experts * row;
- const uint weights_offset = n_expert_used * row;
- const uint ids_offset = n_experts * row;
-
- float logits_r[experts_per_thread];
-
- const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
[[unroll]]
- for (uint i = 0; i < n_experts; i += WARP_SIZE) {
- const uint expert = i + gl_LocalInvocationID.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
}
- float max_val = logits_r[0];
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
[[unroll]]
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
}
- max_val = subgroupMax(max_val);
+ sum = subgroupAdd(sum);
- float wt[experts_per_thread];
- float tmp = 0.f;
+ const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = exp(val - max_val);
- tmp += wt[i];
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
}
+}
- tmp = subgroupAdd(tmp);
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+ if (row >= n_rows) {
+ return;
+ }
- const float inv_sum = 1.0f / tmp;
+ const uint logits_offset = n_experts * row;
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+
+ float wt[experts_per_thread];
[[unroll]]
- for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + gl_LocalInvocationID.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (!late_softmax) {
+ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@
void main() {
float output_weights[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@
void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@
void main() {
}
}
+ if (late_softmax) {
+ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+ }
+
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
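
For reference, the masked warp softmax that softmax_warp_inplace implements can be written as a plain CPU routine. The sketch below is an editorial illustration only, not part of the patch: it reproduces the shader's semantics with subgroupMax/subgroupAdd replaced by ordinary loops, where only the first limit entries participate when use_limit is set and the remaining entries are forced to zero.

#include <algorithm>
#include <cmath>
#include <vector>

// CPU analogue of the shader's softmax_warp_inplace (illustrative only).
static void softmax_masked_inplace(std::vector<float> &vals, size_t limit, bool use_limit) {
    const size_t n = use_limit ? std::min(limit, vals.size()) : vals.size();
    float max_val = -INFINITY;
    for (size_t i = 0; i < n; ++i) {
        max_val = std::max(max_val, vals[i]);   // subgroupMax over active lanes
    }
    float sum = 0.0f;
    for (size_t i = 0; i < vals.size(); ++i) {
        if (i < n) {
            vals[i] = std::exp(vals[i] - max_val);
            sum += vals[i];                     // subgroupAdd over active lanes
        } else {
            vals[i] = 0.0f;                     // inactive entries contribute nothing
        }
    }
    const float inv_sum = 1.0f / sum;
    for (size_t i = 0; i < n; ++i) {
        vals[i] *= inv_sum;
    }
}

The early path applies this over all n_experts logits with no limit; the late path applies it over the n_expert_used selected weights held in output_weights after the top-k loop.
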
llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch
0 → 100644
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 15:13:10 -0500
Subject: [PATCH] vulkan: Fuse rope+set_rows (#16769)
This pattern appears in a lot of models, the rope operation is applied right
before storing into the KV cache (usually on the K tensor).
Add a path to some of the rope shaders that computes the destination address
based on the set_rows tensor. Compile variants of the shader with D_TYPE of
f16 (the usual KV cache type).
Add a src3 operand to ggml_vk_op_f32 - sometimes rope uses three srcs and needs
the fourth for the row indices.
Add fused_ops_write_mask to indicate which intermediate tensors need to write
their results to memory. Skipping writing the roped K value helps to allow more
nodes to run concurrently.
Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive. It
rarely starts out that way in the graph.
Add new backend tests.
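
For context, the graph shape this fusion targets can be sketched with the public ggml API. The helper below is purely illustrative; the names, shapes, and rope parameters are assumptions rather than code from this commit. It shows the ROPE -> VIEW -> SET_ROWS chain that ggml_vk_graph_optimize reorders to be consecutive and that the new shader variants execute in a single dispatch.

#include "ggml.h"

// Hypothetical graph-building helper illustrating the fused pattern.
static ggml_tensor * store_roped_k(ggml_context * ctx,
                                   ggml_tensor  * k_cur,   // F32 [head_dim, n_head_kv, n_tokens]
                                   ggml_tensor  * pos,     // I32 [n_tokens] token positions
                                   ggml_tensor  * k_cache, // F16 KV-cache buffer
                                   ggml_tensor  * kv_idxs, // I64 [n_tokens] destination rows
                                   int n_rot, int n_ctx_orig,
                                   float freq_base, float freq_scale) {
    // 1) rope the current K tensor (NEOX layout here; the norm variant is also supported)
    ggml_tensor * k_rot = ggml_rope_ext(ctx, k_cur, pos, nullptr,
                                        n_rot, GGML_ROPE_TYPE_NEOX, n_ctx_orig,
                                        freq_base, freq_scale,
                                        /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f,
                                        /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

    // 2) view the result as one row per token so the rows can be scattered into the cache
    const int64_t row_elts = k_rot->ne[0] * k_rot->ne[1];
    ggml_tensor * k_rows = ggml_view_2d(ctx, k_rot, row_elts, k_rot->ne[2],
                                        ggml_row_size(k_rot->type, row_elts), 0);

    // 3) write the roped rows (converted to the cache's f16 type) at kv_idxs;
    //    with this patch the Vulkan backend can skip writing k_rot to memory
    return ggml_set_rows(ctx, k_cache, k_rows, kv_idxs);
}

With the fusion in place the roped K values feed set_rows directly, so the intermediate f32 tensor only needs to be written out when another node consumes it, which is what fused_ops_write_mask tracks.
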
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 334 +++++++++++++-----
.../ggml-vulkan/vulkan-shaders/rope_head.glsl | 2 +
.../ggml-vulkan/vulkan-shaders/rope_neox.comp | 13 +-
.../ggml-vulkan/vulkan-shaders/rope_norm.comp | 13 +-
.../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +
tests/test-backend-ops.cpp | 122 +++++--
6 files changed, 371 insertions(+), 117 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index b2855b078..aaf4334b5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -458,6 +458,11 @@
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
return mode;
}
+static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
+ { 1, 0, 0 }, // view->src[0] == rope
+ { 2, 0, 1 }, // set_rows->src[0] == view
+};
+
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -640,8 +645,8 @@
struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_back_f32;
- vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
@@ -1054,6 +1059,7 @@
struct vk_op_rope_push_constants {
uint32_t s2;
int32_t sections[4];
uint32_t is_back;
+ uint32_t set_rows_stride;
};
struct vk_op_soft_max_push_constants {
@@ -1563,6 +1569,10 @@
struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the
// node currently being processed
int num_additional_fused_ops {};
+ // Bitmask of which fused ops need to write an intermediate value to memory.
+ // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
+ // If there's no fusion, bit 0 is still set.
+ int fused_ops_write_mask {};
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -3697,21 +3707,27 @@
static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -8170,7 +8186,8 @@
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
- const int mode = ((const int32_t *) dst->op_params)[2];
+ const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
+ const int mode = ((const int32_t *) rope->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
@@ -8179,6 +8196,9 @@
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_neox_f32;
}
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_neox_f32_f16;
+ }
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
}
@@ -8200,6 +8220,9 @@
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_norm_f32;
}
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_norm_f32_f16;
+ }
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_norm_f16;
}
@@ -8409,20 +8432,22 @@
static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_ten
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
}
-template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
GGML_UNUSED(p);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
GGML_UNUSED(dst);
static_assert(!std::is_const<T>::value, "unexpected type");
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
+ GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8430,9 +8455,10 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8440,9 +8466,10 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8450,9 +8477,10 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8460,9 +8488,10 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src0);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8472,9 +8501,10 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8483,10 +8513,11 @@
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
if (src1 != nullptr) {
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -8494,6 +8525,9 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (src2 != nullptr) {
std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
}
+ if (src3 != nullptr) {
+ std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
+ }
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
@@ -8520,6 +8554,13 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
const uint64_t ne2 = ne20 * ne21;
+ const bool use_src3 = src3 != nullptr;
+ const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
+ const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
+ const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
+ const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
+ const uint64_t ne3 = ne30 * ne31;
+
const uint64_t ned0 = dst->ne[0];
const uint64_t ned1 = dst->ne[1];
const uint64_t ned2 = dst->ne[2];
@@ -8550,6 +8591,7 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
+ ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
vk_buffer d_X = nullptr;
size_t x_buf_offset = 0;
@@ -8557,10 +8599,13 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
size_t y_buf_offset = 0;
vk_buffer d_Z = nullptr;
size_t z_buf_offset = 0;
+ vk_buffer d_W = nullptr;
+ size_t w_buf_offset = 0;
bool src0_uma = false;
bool src1_uma = false;
bool src2_uma = false;
+ bool src3_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
@@ -8573,6 +8618,10 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
src2_uma = d_Z != nullptr;
}
+ if (use_src3) {
+ ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
+ src3_uma = d_W != nullptr;
+ }
}
vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -8594,11 +8643,17 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
}
+ if (use_src3 && !src3_uma) {
+ d_W = src3_buf_ctx->dev_buffer;
+ w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
+ GGML_ASSERT(d_W != nullptr);
+ }
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
- init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
+ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
+ w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
std::array<uint32_t, 3> elements;
@@ -8799,12 +8854,13 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
}
- uint64_t x_sz, y_sz, z_sz, d_sz;
+ uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+ w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
@@ -8816,6 +8872,9 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
+ if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
+ w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
+ }
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
@@ -8823,6 +8882,7 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
+ w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}
@@ -8864,14 +8924,19 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
- vk_subbuffer subbuf_z;
+ vk_subbuffer subbuf_z, subbuf_w;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}
+ if (use_src3) {
+ subbuf_w = { d_W, w_buf_offset, w_sz };
+ } else {
+ subbuf_w = { d_X, 0, x_sz };
+ }
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
@@ -8887,6 +8952,8 @@
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
+ } else if (use_src3) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src2) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src1) {
@@ -8901,7 +8968,7 @@
static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -8921,7 +8988,7 @@
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9046,7 +9113,7 @@
static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9061,7 +9128,7 @@
static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9076,7 +9143,7 @@
static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9091,7 +9158,7 @@
static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9106,7 +9173,7 @@
static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t src2_type_size = ggml_type_size(src2->type);
- ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
(uint32_t)dst->ne[0],
(uint32_t)dst->ne[1],
(uint32_t)src0->nb[1] / src0_type_size,
@@ -9339,7 +9406,7 @@
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx,
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
+ ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@@ -9457,7 +9524,7 @@
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9467,7 +9534,7 @@
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9491,7 +9558,7 @@
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
}
- ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
+ ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
@@ -9505,23 +9572,23 @@
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
}
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9529,12 +9596,12 @@
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
}
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
}
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9549,17 +9616,17 @@
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, cons
memcpy(&p.param1, &s01_packed, sizeof(float));
memcpy(&p.param2, &s23_packed, sizeof(float));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
}
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9575,7 +9642,7 @@
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
}
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
}
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9590,7 +9657,7 @@
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
return;
}
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9601,13 +9668,13 @@
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9618,7 +9685,7 @@
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
const float eps = float_op_params[1];
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -9641,7 +9708,7 @@
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9658,16 +9725,16 @@
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9690,7 +9757,7 @@
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
{
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0],
@@ -9703,7 +9770,7 @@
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
int32_t * op_params = (int32_t *)dst->op_params;
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -9728,7 +9795,7 @@
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -9744,7 +9811,7 @@
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
@@ -9835,7 +9902,12 @@
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
}, pc, elements);
}
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) {
+ ggml_tensor * dst = cgraph->nodes[node_idx];
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+ const ggml_tensor * src3 = nullptr;
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -9859,11 +9931,20 @@
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
+ uint32_t set_rows_stride = 0;
+ // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
+ // and overrides the dst and sets src3=row_indices
+ if (ctx->num_additional_fused_ops > 0) {
+ set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
+ src3 = cgraph->nodes[node_idx + 2]->src[1];
+ dst = cgraph->nodes[node_idx + 2];
+ }
+
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
- { sections[0], sections[1], sections[2], sections[3] }, backprop
+ { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
}, dryrun);
}
@@ -9872,7 +9953,7 @@
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
uint32_t ncols = src0->ne[0];
- ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
op_params[0],
}, dryrun);
@@ -9880,26 +9961,26 @@
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
}
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
p.weight = 1.0f / (float)src0->ne[0];
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9932,7 +10013,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
- ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
+ ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
@@ -10005,7 +10086,7 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
- ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
+ ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -10013,7 +10094,7 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
const uint32_t max_period = dst->op_params[1];
const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
+ ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
nb1, dim, max_period,
}, dryrun);
}
@@ -10046,7 +10127,7 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.s0 = static_cast<uint32_t>(s0);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
}
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -10069,7 +10150,7 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
const uint32_t parallel_elements = N * OC * OH * OW;
- ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+ ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
IW, IH, OW, OH, OC,
parallel_elements,
op,
@@ -10123,7 +10204,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
@@ -10172,7 +10253,7 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context
GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne12);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -10196,12 +10277,12 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(src0->ne[3] == p.channels);
GGML_ASSERT(src1->ne[3] == p.batches);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
}
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -11327,7 +11408,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
@@ -11401,9 +11481,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// nodes require synchronization.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
- if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
- need_sync = true;
- break;
+ // If the node actually writes to memory, then check if it needs to sync
+ if (ctx->fused_ops_write_mask & (1 << i)) {
+ if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
+ need_sync = true;
+ break;
+ }
}
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
@@ -11430,7 +11513,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
- ctx->unsynced_nodes_written.push_back(cur_node);
+ if (ctx->fused_ops_write_mask & (1 << i)) {
+ ctx->unsynced_nodes_written.push_back(cur_node);
+ }
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
@@ -11621,11 +11706,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false, dryrun);
break;
case GGML_OP_ROPE_BACK:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true, dryrun);
break;
case GGML_OP_ARGSORT:
@@ -12487,6 +12572,41 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return true;
}
+static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+ int node_idx) {
+ GGML_UNUSED(ctx);
+ const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
+ const ggml_tensor *view = cgraph->nodes[node_idx + 1];
+ const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
+
+ // ne3 not tested
+ if (rope->src[0]->ne[3] != 1) {
+ return false;
+ }
+
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ // The view should flatten two dims of rope into one dim
+ if (!ggml_is_contiguous(view) ||
+ view->ne[0] != rope->ne[0] * rope->ne[1]) {
+ return false;
+ }
+
+ // Only norm/neox shaders have the fusion code
+ const int mode = ((const int32_t *) rope->op_params)[2];
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+ return false;
+ }
+
+ return true;
+}
+
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12562,6 +12682,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+ ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -12671,20 +12795,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+ ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 3;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 3;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 1;
}
}
+ ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
@@ -12730,6 +12865,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
+ ctx->fused_ops_write_mask = 0;
}
if (vk_perf_logger_enabled) {
@@ -12887,6 +13023,32 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
if (ok) {
current_set.push_back(j);
+ // Look for ROPE + VIEW + SET_ROWS and make them consecutive
+ if (graph->nodes[j]->op == GGML_OP_ROPE) {
+ int view_idx = -1;
+ int set_rows_idx = -1;
+ for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
+ if (view_idx == -1 &&
+ graph->nodes[k]->op == GGML_OP_VIEW &&
+ graph->nodes[k]->src[0] == graph->nodes[j]) {
+ view_idx = k;
+ continue;
+ }
+ if (view_idx != -1 &&
+ set_rows_idx == -1 &&
+ graph->nodes[k]->op == GGML_OP_SET_ROWS &&
+ graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
+ set_rows_idx = k;
+ break;
+ }
+ }
+ if (set_rows_idx != -1) {
+ current_set.push_back(view_idx);
+ current_set.push_back(set_rows_idx);
+ used[view_idx] = true;
+ used[set_rows_idx] = true;
+ }
+ }
}
}
// Second pass grabs view nodes.
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
index 50fc1f1e2..0eda186c8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
@@ -10,6 +10,7 @@
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
layout (push_constant) uniform parameter {
uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
uint s2;
int sections[4];
uint is_back;
+ uint set_rows_stride;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 06e095bef..9f4538155 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
- const uint idst = row_dst*ne0 + i0/2;
+ uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+ // Fusion optimization: ROPE + VIEW + SET_ROWS..
+ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+ if (p.set_rows_stride != 0) {
+ idst = row_x*ne0 + i0/2;
+ idst += data_i[channel_x].x * p.set_rows_stride;
+ }
+
if (i0 >= p.n_dims) {
- data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
- data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+ data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
+ data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
return;
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 6ba957540..f4209ed95 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
- const uint idst = row_dst*ne0 + i0;
+ uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
+ // Fusion optimization: ROPE + VIEW + SET_ROWS..
+ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+ if (p.set_rows_stride != 0) {
+ idst = row_x*ne0 + i0;
+ idst += data_i[channel_x].x * p.set_rows_stride;
+ }
+
if (i0 >= p.n_dims) {
- data_d[idst + 0] = data_a[ix + 0];
- data_d[idst + 1] = data_a[ix + 1];
+ data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
+ data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
return;
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 03fa01639..e6ec589fb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -842,10 +842,14 @@ void process_shaders() {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+ string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+ string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9eb2b6687..657b6cc2f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2105,6 +2105,34 @@ struct test_get_rows_back : public test_case {
}
};
+static void init_set_rows_row_ids(ggml_tensor * t, int num_rows) {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (int i2 = 0; i2 < t->ne[2]; i2++) {
+ for (int i1 = 0; i1 < t->ne[1]; i1++) {
+ // generate a shuffled subset of row indices
+ std::vector<int64_t> data(num_rows);
+ for (int i = 0; i < num_rows; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ data.resize(t->ne[0]);
+
+ const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+ if (t->type == GGML_TYPE_I32) {
+ // TODO: Make a template or something
+ std::vector<int32_t> data_i32(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data_i32[i] = static_cast<int32_t>(data[i]);
+ }
+ ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
+ } else {
+ ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+ }
+ }
+ }
+}
+
// GGML_OP_SET_ROWS
struct test_set_rows : public test_case {
const ggml_type type;
@@ -2148,37 +2176,13 @@ struct test_set_rows : public test_case {
}
void initialize_tensors(ggml_context * ctx) override {
- std::random_device rd;
- std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) {
continue;
}
- for (int i2 = 0; i2 < t->ne[2]; i2++) {
- for (int i1 = 0; i1 < t->ne[1]; i1++) {
- // generate a shuffled subset of row indices
- std::vector<int64_t> data(ne[1]);
- for (int i = 0; i < ne[1]; i++) {
- data[i] = i;
- }
- std::shuffle(data.begin(), data.end(), rng);
- data.resize(t->ne[0]);
-
- const size_t offs = i1*t->nb[1] + i2*t->nb[2];
- if (t->type == GGML_TYPE_I32) {
- // TODO: Make a template or something
- std::vector<int32_t> data_i32(t->ne[0]);
- for (int i = 0; i < t->ne[0]; i++) {
- data_i32[i] = static_cast<int32_t>(data[i]);
- }
- ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
- } else {
- ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
- }
- }
- }
+ init_set_rows_row_ids(t, ne[1]);
} else {
init_tensor_uniform(t);
}
@@ -2207,6 +2211,67 @@ struct test_set_rows : public test_case {
}
};
+// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS
+struct test_rope_set_rows : public test_case {
+ const ggml_type type;
+ const ggml_type type_idx;
+ const std::array<int64_t, 4> ne;
+ int mode;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, type_idx, ne, mode);
+ }
+
+ std::string op_desc(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ return "ROPE_SET_ROWS";
+ }
+
+ bool run_whole_graph() override { return true; }
+
+ test_rope_set_rows(ggml_type type,
+ ggml_type type_idx,
+ std::array<int64_t, 4> ne,
+ int mode)
+ : type(type), type_idx(type_idx), ne(ne), mode(mode) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+ ggml_set_name(src, "src");
+
+ ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+ ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode);
+
+ ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+ ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+ ggml_set_name(dst, "dst");
+
+ ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1);
+ ggml_set_name(row_idxs, "row_idxs");
+
+ ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) {
+ continue;
+ }
+
+ init_set_rows_row_ids(t, ne[2]);
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
@@ -6008,6 +6073,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
+ for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) {
+ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode));
+ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode));
+ }
+ }
+
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
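
Note on the fused ROPE + VIEW + SET_ROWS addressing used by the rope_norm/rope_neox shaders patched above: when set_rows_stride is non-zero, the shader drops the plain row_dst offset and instead scatters each channel's rows by the row id read from data_i. A minimal CPU-side sketch of that index computation (the helper name, the main() driver, and the sample values are illustrative, not part of the patch):

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the fused ROPE+VIEW+SET_ROWS addressing from rope_norm.comp.
// ne0/ne1 are the rope row size and row count; set_rows_stride is the
// destination row stride of the SET_ROWS tensor (0 disables the fused path).
static uint32_t fused_dst_index(uint32_t row_dst, uint32_t i0,
                                uint32_t ne0, uint32_t ne1,
                                uint32_t set_rows_stride,
                                const std::vector<uint64_t> &row_ids) {
    const uint32_t row_x     = row_dst % ne1;
    const uint32_t channel_x = row_dst / ne1;
    uint32_t idst = row_dst * ne0 + i0;                 // plain ROPE output offset
    if (set_rows_stride != 0) {
        // rope output viewed as one row of ne0*ne1 elements, scattered by row id
        idst  = row_x * ne0 + i0;
        idst += static_cast<uint32_t>(row_ids[channel_x] * set_rows_stride);
    }
    return idst;
}

int main() {
    std::vector<uint64_t> row_ids = {7, 3};             // hypothetical KV slots
    // channel 1, row 2, element 4 of an 8x4 rope output, destination stride 32
    printf("%u\n", fused_dst_index(1 * 4 + 2, 4, 8, 4, 32, row_ids));
    return 0;
}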
llama/patches/0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch
0 → 100644
View file @
3a9e8e9f
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}
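
For context on the argsort change above: the dispatch is clamped to the device's maxComputeWorkGroupCount[1], and each workgroup then walks rows with a stride equal to the launched grid size, so row counts past the limit are still covered. A rough host-side sketch of that striding, assuming the shader's workgroup size in y is 1 (names and values below are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Sketch of the grid-stride row loop added to argsort.comp: the dispatch is
// clamped to the device limit, and each workgroup then walks rows in steps of
// the number of workgroups actually launched.
int main() {
    const uint32_t nrows               = 200000;  // hypothetical ggml_nrows(src0)
    const uint32_t max_workgroup_count = 65535;   // typical maxComputeWorkGroupCount[1]

    const uint32_t launched = std::min(nrows, max_workgroup_count);

    uint64_t rows_processed = 0;
    for (uint32_t wg = 0; wg < launched; ++wg) {                 // gl_WorkGroupID.y
        for (uint32_t row = wg; row < nrows; row += launched) {  // loop in the shader
            ++rows_processed;                                    // argsort(..., row)
        }
    }
    printf("launched=%u rows=%llu\n", launched, (unsigned long long) rows_processed);
    return 0;
}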
llama/patches/0034-vulkan-fix-shmem-overrun-in-mmq-id-shader-16873.patch
0 → 100644
View file @
3a9e8e9f
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +-
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
tests/test-backend-ops.cpp | 3 +++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
+#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ // gpt-oss issue with Vulkan mmq_id
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
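
Rough sizing note for the shmem fix above: the mul_mmq A-cache is declared as buf_a[BM * BK_STEP], so forcing BK_STEP to 1 in the MUL_MAT_ID variant shrinks the shared-memory footprint relative to the non-id default of 4. The sketch below only approximates block_a_cache with a q8_0-like layout; the exact GLSL struct and the BM value depend on the selected warptile:

#include <cstdint>
#include <cstdio>

// Rough shared-memory estimate for the mul_mmq A-cache: buf_a[BM * BK_STEP].
// The struct here is an illustrative stand-in for block_a_cache (q8_0: eight
// packed int32 quants plus a scale), not the exact GLSL layout.
struct block_a_cache_approx {
    int32_t qs[32 / 4];
    float   dm;
};

static size_t shmem_bytes(uint32_t BM, uint32_t BK_STEP) {
    return size_t(BM) * BK_STEP * sizeof(block_a_cache_approx);
}

int main() {
    const uint32_t BM = 128;                                // illustrative tile height
    printf("BK_STEP=4: %zu bytes\n", shmem_bytes(BM, 4));   // non-id path default
    printf("BK_STEP=1: %zu bytes\n", shmem_bytes(BM, 1));   // forced for MUL_MAT_ID
    return 0;
}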
llama/patches/0035-vulkan-Fix-crash-when-FP16-mul_mat-accumulation-is-n.patch
0 → 100644
View file @
3a9e8e9f
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
supported (#16796)
* Experimenting crash fix
* added assert for aborting and fixed comment
* changed to check if a pipeline is empty or not
* Moved function in class definition
* replaced with is_empty
* Modified is_empty to check only unaligned pipelines
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
+ // Returns true when all unaligned pipelines are null.
+ // We only check for unaligned variants since one of the unaligned pipelines must exist
+ // while aligned pipelines are optional
+ bool is_empty() const {
+ return l == nullptr && m == nullptr && s == nullptr;
+ }
};
-
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+ bool support_fp16acc = !mmp.f16acc->is_empty();
+ bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ return mmp.f32acc;
}
}
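
The crash fix above relies on the fact that the f16acc/f32acc members of vk_matmul_pipeline2 are shared_ptrs that always exist, so a null check on them never fails; is_empty() instead inspects the unaligned pipelines inside. A condensed sketch of the resulting accumulator selection, with simplified stand-in types (not the real backend structs):

#include <cassert>
#include <memory>

// Condensed stand-ins for the Vulkan backend types; only the selection logic
// from ggml_vk_get_mul_mat_mat_id_pipeline is reproduced here.
struct pipeline { };
using vk_pipeline = std::shared_ptr<pipeline>;

struct vk_matmul_pipeline_struct {
    vk_pipeline l, m, s;
    vk_pipeline a_l, a_m, a_s;
    // Only the unaligned variants are checked: at least one of them must exist
    // for the pipeline set to be usable, aligned variants are optional.
    bool is_empty() const { return l == nullptr && m == nullptr && s == nullptr; }
};
using vk_matmul_pipeline = std::shared_ptr<vk_matmul_pipeline_struct>;

struct vk_matmul_pipeline2 {
    vk_matmul_pipeline f16acc = std::make_shared<vk_matmul_pipeline_struct>();
    vk_matmul_pipeline f32acc = std::make_shared<vk_matmul_pipeline_struct>();
};

static vk_matmul_pipeline select_acc(vk_matmul_pipeline2 &mmp, bool prefer_fp16acc) {
    const bool support_fp16acc = !mmp.f16acc->is_empty();
    const bool support_fp32acc = !mmp.f32acc->is_empty();
    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
        return mmp.f16acc;
    }
    assert(support_fp32acc);   // mirrors the GGML_ASSERT in the patch
    return mmp.f32acc;
}

int main() {
    vk_matmul_pipeline2 mmp;
    mmp.f32acc->m = std::make_shared<pipeline>();  // only an f32 medium pipeline exists
    vk_matmul_pipeline chosen = select_acc(mmp, /*prefer_fp16acc=*/true);
    return chosen == mmp.f32acc ? 0 : 1;           // falls back to the f32 accumulator
}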
ml/backend/ggml/ggml/src/ggml-impl.h
View file @
3a9e8e9f
...
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
 #endif
 
 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>
...
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }
 
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+                             int start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto & edge : edges) {
+        int dst_node = edge[0];
+        int src_idx  = edge[1];
+        int src_node = edge[2];
+
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
...
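
Each edge triple consumed by ggml_check_edges above is { dst_node, src_idx, src_node }, with both node offsets relative to the first node of the candidate subgraph. A self-contained sketch with simplified node/graph types (hypothetical, for illustration only), wired the way rope_view_set_rows_edges expects:

#include <array>
#include <cstdio>
#include <initializer_list>
#include <vector>

// Simplified stand-ins for ggml_tensor/ggml_cgraph, just enough to show how
// the { dst_node, src_idx, src_node } triples are interpreted.
struct node { const node * src[2] = { nullptr, nullptr }; };

static bool check_edges(const std::vector<node *> & nodes, int start_idx,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & edge : edges) {
        const int dst_node = edge[0];
        const int src_idx  = edge[1];
        const int src_node = edge[2];
        if (nodes[start_idx + dst_node]->src[src_idx] != nodes[start_idx + src_node]) {
            return false;
        }
    }
    return true;
}

int main() {
    // rope -> view -> set_rows, matching { 1, 0, 0 } and { 2, 0, 1 }
    node rope, view, set_rows;
    view.src[0]     = &rope;
    set_rows.src[0] = &view;
    std::vector<node *> nodes = { &rope, &view, &set_rows };
    printf("%d\n", check_edges(nodes, 0, { { 1, 0, 0 }, { 2, 0, 1 } }));
    return 0;
}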
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
View file @
3a9e8e9f
...
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
...
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
View file @
3a9e8e9f
...
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
...
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
vk_pipeline a_l, a_m, a_s;
// Returns true when all unaligned pipelines are null.
// We only check for unaligned variants since one of the unaligned pipelines must exist
// while aligned pipelines are optional
bool is_empty() const {
return l == nullptr && m == nullptr && s == nullptr;
}
};
};
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
struct vk_matmul_pipeline2 {
...
@@ -387,12 +392,81 @@ static constexpr uint32_t num_argsort_pipelines = 11;
...
@@ -387,12 +392,81 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_RESHAPE };
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
{ 2, 0, 0 }, // argsort->src[0] == softmax
{ 3, 0, 2 }, // view->src[0] == argsort
{ 4, 0, 1 }, // get_rows->src[0] == reshape
{ 4, 1, 3 }, // get_rows->src[1] == view
{ 5, 0, 4 }, // reshape->src[0] == get_rows
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
{ 7, 0, 6 }, // clamp->src[0] == sum_rows
{ 8, 0, 5 }, // div->src[0] == reshape
{ 8, 1, 7 }, // div->src[1] == clamp
{ 9, 0, 8 }, // reshape->src[0] == div
};
// same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
{ 2, 0, 0 }, // argsort->src[0] == softmax
{ 3, 0, 2 }, // view->src[0] == argsort
{ 4, 0, 1 }, // get_rows->src[0] == reshape
{ 4, 1, 3 }, // get_rows->src[1] == view
};
//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
{ 1, 0, 0 }, // view->src[0] == argsort
{ 2, 1, 1 }, // get_rows->src[1] == view
{ 3, 0, 2 }, // reshape->src[0] == get_rows
{ 4, 0, 3 }, // soft_max->src[0] == reshape
{ 5, 0, 4 }, // reshape->src[0] == soft_max
};
enum topk_moe_mode {
TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_LATE_SOFTMAX,
TOPK_MOE_COUNT,
};
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
TOPK_MOE_LATE_SOFTMAX;
return mode;
}
static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
{ 1, 0, 0 }, // view->src[0] == rope
{ 2, 0, 1 }, // set_rows->src[0] == view
};
struct vk_device_struct {
struct vk_device_struct {
std::recursive_mutex mutex;
std::recursive_mutex mutex;
...
@@ -488,6 +562,7 @@ struct vk_device_struct {
...
@@ -488,6 +562,7 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
vk_pipeline pipeline_quantize_q8_1;
...
@@ -575,8 +650,8 @@ struct vk_device_struct {
...
@@ -575,8 +650,8 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_back_f32;
vk_pipeline pipeline_soft_max_back_f32;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16
, pipeline_rope_norm_f32_f16
;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16
, pipeline_rope_neox_f32_f16
;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
...
@@ -606,8 +681,7 @@ struct vk_device_struct {
...
@@ -606,8 +681,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
std::vector<vk_pipeline_ref> all_pipelines;
...
@@ -955,6 +1029,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
...
@@ -955,6 +1029,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_rows;
uint32_t n_expert_used;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
};
};
struct vk_op_add_id_push_constants {
struct vk_op_add_id_push_constants {
...
@@ -988,6 +1064,7 @@ struct vk_op_rope_push_constants {
...
@@ -988,6 +1064,7 @@ struct vk_op_rope_push_constants {
uint32_t s2;
uint32_t s2;
int32_t sections[4];
int32_t sections[4];
uint32_t is_back;
uint32_t is_back;
uint32_t set_rows_stride;
};
};
struct vk_op_soft_max_push_constants {
struct vk_op_soft_max_push_constants {
...
@@ -1012,6 +1089,7 @@ struct vk_op_soft_max_push_constants {
...
@@ -1012,6 +1089,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
uint32_t ncols;
uint32_t nrows;
int32_t order;
int32_t order;
};
};
...
@@ -1497,6 +1575,10 @@ struct ggml_backend_vk_context {
...
@@ -1497,6 +1575,10 @@ struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the
// number of additional consecutive nodes that are being fused with the
// node currently being processed
// node currently being processed
int num_additional_fused_ops {};
int num_additional_fused_ops {};
// Bitmask of which fused ops need to write an intermediate value to memory.
// Bit 'i' means nodes[start_of_fusion + i] writes to memory.
// If there's no fusion, bit 0 is still set.
int fused_ops_write_mask {};
};
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
...
@@ -2449,8 +2531,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2449,8 +2531,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
...
@@ -2513,10 +2598,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2513,10 +2598,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
// K-quants use even more registers, mitigate by setting WMITER to 1
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
...
@@ -2525,10 +2616,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2525,10 +2616,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
// chip specific tuning
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid =
m_warptile_mmqid_int =
{ 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
...
@@ -2913,18 +3012,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2913,18 +3012,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID
, REQSUBGROUPSIZE
) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
if (device->mul_mat ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
} \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
if (device->mul_mat ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
} \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
if (device->mul_mat ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
} \
} \
// Create 2 variants, {f16,f32} accumulator
// Create 2 variants, {f16,f32} accumulator
...
@@ -2963,11 +3059,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2963,11 +3059,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
}
}
#endif
#endif
...
@@ -2997,6 +3101,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
...
@@ -2997,6 +3101,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
}
#endif
} else {
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
...
@@ -3023,6 +3145,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#endif
}
}
#undef CREATE_MM2
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MMQ
...
@@ -3087,6 +3227,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
}
#endif
#endif
...
@@ -3146,7 +3292,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}
// reusing CREATE_MM from the fp32 path
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
if ((device->coopmat2 || device->coopmat_support)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
&& !device->coopmat_bf16_support
&& !device->coopmat_bf16_support
#endif
#endif
) {
) {
...
@@ -3567,21 +3713,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
} else {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
}
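The rope pipelines above move from 4 to 5 descriptor bindings and gain f32-input/f16-output variants (rope_norm_f32_f16, rope_neox_f32_f16). This lines up with the rope+set_rows fusion patch carried in this commit; the extra binding presumably carries the fused set_rows destination/row-index buffer. A minimal sketch of the type-based selection the new variants enable (illustrative helper, not the backend's dispatch code, which lives in ggml_vk_op_get_pipeline further down):

// Sketch only: names mirror the pipelines registered above, the helper is hypothetical.
#include <stdexcept>
#include <string>

enum class DType { F32, F16 };

std::string pick_rope_pipeline(bool neox, DType src, DType dst) {
    const std::string base = neox ? "rope_neox" : "rope_norm";
    if (src == DType::F32 && dst == DType::F32) return base + "_f32";
    if (src == DType::F32 && dst == DType::F16) return base + "_f32_f16"; // new: rope fused with set_rows into an f16 cache
    if (src == DType::F16 && dst == DType::F16) return base + "_f16";
    throw std::invalid_argument("unsupported rope src/dst type combination");
}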
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
...
@@ -3741,8 +3893,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
}
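The topk_moe fusion now builds three pipeline variants per power-of-two expert count instead of two, and the last two specialization constants act as (norm, late_softmax) flags. A sketch of the mapping implied by the three ggml_vk_create_pipeline2 calls above; the LATE_SOFTMAX value and the mode semantics in the comments are inferred from the names, not stated in this hunk:

// Sketch of the mode -> specialization-constant mapping used above.
#include <array>
#include <cstdint>

enum topk_moe_mode {
    TOPK_MOE_EARLY_SOFTMAX      = 0, // softmax over all experts, then top-k
    TOPK_MOE_EARLY_SOFTMAX_NORM = 1, // as above, plus renormalization of the selected weights
    TOPK_MOE_LATE_SOFTMAX       = 2, // top-k first, softmax only over the selected experts (assumed)
};

// {norm, late_softmax} pushed as the 3rd and 4th spec constants.
constexpr std::array<std::array<uint32_t, 2>, 3> topk_moe_spec_flags = {{
    {0u, 0u}, {1u, 0u}, {0u, 1u},
}};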
for (auto &c : compiles) {
for (auto &c : compiles) {
...
@@ -4930,9 +5083,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
// MMQ
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
if (pipelines->is_empty()) {
return nullptr;
return nullptr;
}
}
...
@@ -5077,6 +5230,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
}
}
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
if (pipelines->is_empty()) {
return nullptr;
}
return pipelines;
}
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
switch (src0_type) {
...
@@ -5105,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
return nullptr;
}
}
vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
bool support_fp16acc = !mmp.f16acc->is_empty();
bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
return mmp.f16acc;
} else {
} else {
GGML_ASSERT(support_fp32acc);
GGML_ASSERT(support_fp32acc);
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
return mmp.f32acc;
}
}
}
}
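The accumulator choice above now keys off the refactored vk_matmul_pipeline2 entry (mmp) and its is_empty() checks instead of raw pointer comparisons; the policy itself is unchanged: prefer fp16 accumulation when the device supports fp16 and an f16acc variant was compiled, otherwise fall back to (and require) f32acc. A condensed restatement with stand-in types:

// Condensed sketch of the fp16acc/fp32acc selection above (stand-in types, not the backend API).
struct matmul_pipeline2_sketch {
    bool has_f16acc;
    bool has_f32acc;
};

enum class Acc { F16, F32 };

Acc pick_accumulator(const matmul_pipeline2_sketch &mmp, bool device_fp16) {
    const bool prefer_fp16acc = device_fp16; // 'prec' is intentionally ignored for mul_mat_id
    if (mmp.has_f16acc && (prefer_fp16acc || !mmp.has_f32acc)) {
        return Acc::F16;
    }
    // the real code asserts that the f32acc variant exists at this point
    return Acc::F32;
}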
...
@@ -5654,14 +5819,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
ggml_vk_ensure_sync_staging_buffer(src->device, size);
ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
// memcpy to dst staging buffer
memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
// Copy to dst buffer
ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}
}
}
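The multi-device copy path above now stages only on the source device: data is copied into the source device's sync staging buffer and then written straight into the destination buffer via ggml_vk_buffer_write_2d from the staging buffer's mapped pointer, so the destination-side staging buffer and the host-to-host memcpy disappear. A simplified model of why a 2D write with height 1 behaves as a plain upload (a sketch, not the backend's implementation):

// Sketch: a 2D buffer write copies 'height' rows of 'width' bytes, advancing the
// source by 'spitch' per row. With height == 1 (as in the call above) it reduces
// to a single flat copy of 'size' bytes.
#include <cstdint>
#include <cstring>

void write_2d_sketch(uint8_t *dst, size_t dst_offset, const uint8_t *src,
                     size_t spitch, size_t width, size_t height) {
    for (size_t row = 0; row < height; ++row) {
        std::memcpy(dst + dst_offset + row * width, src + row * spitch, width);
    }
}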
...
@@ -6882,10 +7044,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
if (mmp == nullptr) {
// Fall back to f16 dequant mul mat
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
quantize_y = false;
}
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
// Fall back to dequant + f16 mulmat
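The selection above now checks for an integer-dot (Q8_1) mul_mat_id pipeline first: src1 is only quantized when it is contiguous F32, the device reports integer dot product support, and the row size is a multiple of 4; if no MMQ pipeline exists for src0's type, the code falls back to the existing f16/f32 path and clears quantize_y. A compressed restatement of that decision, with illustrative names:

// Sketch of the pipeline-selection order introduced above (names are illustrative).
#include <cstdint>

enum class YPath { QUANTIZE_Q8_1, F16_OR_F32 };

YPath choose_y_path(bool integer_dot, bool src1_is_f32_contig,
                    int64_t ne10, int64_t ne11, bool mmq_id_pipeline_exists) {
    bool quantize_y = integer_dot && src1_is_f32_contig && (ne11 * ne10) % 4 == 0;
    if (quantize_y && mmq_id_pipeline_exists) {
        return YPath::QUANTIZE_Q8_1; // run the matmul_id_*_q8_1 kernels on quantized activations
    }
    return YPath::F16_OR_F32;        // fall back to the dequant / f32 kernels
}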
...
@@ -6895,8 +7066,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
...
@@ -6909,12 +7080,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
const uint64_t ids_sz = nbi2;
const uint64_t d_sz = sizeof(float) * d_ne;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
...
@@ -6929,9 +7101,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
}
if (dryrun) {
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (quantize_y) {
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
}
if (
if (
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
...
@@ -6940,7 +7119,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
ctx->prealloc_size_x = x_sz_upd;
}
}
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
ctx->prealloc_size_y = y_sz_upd;
}
}
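Both the preallocated Y buffer above and the bound range later in this function are rounded up to a multiple of 144 bytes when quantizing to Q8_1. Assuming the usual ggml layout, a block_q8_1 is 36 bytes (32 int8 quants plus two fp16 scalars) and the Vulkan MMQ shaders read the blocks four at a time, so 144 bytes is one full q8_1_x4 group and the padding keeps the last, partially filled group addressable. A small sketch of the rounding:

// Sketch of the Q8_1 size padding used above (the 36-byte block size and the
// groups-of-four layout are assumptions about the shader-side format).
#include <cstdint>

constexpr uint64_t Q8_1_BLOCK_BYTES = 36;                   // 32 quants + 2 fp16 scalars
constexpr uint64_t Q8_1_X4_BYTES    = 4 * Q8_1_BLOCK_BYTES; // = 144

constexpr uint64_t ceil_div_u64(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

constexpr uint64_t pad_q8_1_size(uint64_t bytes) {
    return ceil_div_u64(bytes, Q8_1_X4_BYTES) * Q8_1_X4_BYTES;
}

static_assert(pad_q8_1_size(1)   == 144, "rounds up to one group");
static_assert(pad_q8_1_size(144) == 144, "exact multiples are unchanged");
static_assert(pad_q8_1_size(145) == 288, "next group starts at 288");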
...
@@ -6952,6 +7131,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
return;
return;
}
}
...
@@ -6988,6 +7170,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
} else {
d_Y = d_Qy;
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
y_buf_offset = qy_buf_offset;
...
@@ -7019,6 +7204,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_tensor_used = src1;
}
}
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
uint32_t stride_batch_y = ne10*ne11;
...
@@ -7027,14 +7223,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
}
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
}
uint32_t y_sz_total = y_sz * ne12 * ne13;
if (quantize_y) {
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}
// compute
// compute
ggml_vk_matmul_id(
ggml_vk_matmul_id(
ctx, subctx, pipeline,
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
stride_batch_x, stride_batch_y, ne20*ne21,
...
@@ -7973,8 +8174,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][with_norm];
return ctx->device->pipeline_topk_moe[idx][mode];
}
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
...
@@ -7992,7 +8193,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ROPE:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ROPE_BACK:
{
{
const int mode = ((const int32_t *) dst->op_params)[2];
const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
const int mode = ((const int32_t *) rope->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
...
@@ -8001,6 +8203,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_neox_f32;
return ctx->device->pipeline_rope_neox_f32;
}
}
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f32_f16;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
return ctx->device->pipeline_rope_neox_f16;
}
}
...
@@ -8022,6 +8227,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_norm_f32;
return ctx->device->pipeline_rope_norm_f32;
}
}
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_norm_f32_f16;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_norm_f16;
return ctx->device->pipeline_rope_norm_f16;
}
}
...
@@ -8029,6 +8237,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
return nullptr;
}
}
case GGML_OP_ARGSORT:
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
return ctx->device->pipeline_argsort_f32[idx];
...
@@ -8224,20 +8439,22 @@ static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_ten
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
}
}
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
GGML_UNUSED(p);
GGML_UNUSED(p);
GGML_UNUSED(src0);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
GGML_UNUSED(dst);
GGML_UNUSED(dst);
static_assert(!std::is_const<T>::value, "unexpected type");
static_assert(!std::is_const<T>::value, "unexpected type");
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8245,9 +8462,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8255,9 +8473,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8265,9 +8484,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8275,9 +8495,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src0);
GGML_UNUSED(src0);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8287,9 +8508,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
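For binary ops the three per-tensor element offsets that compensate for storage-buffer misalignment are packed into a single 32-bit push constant: a in the high 16 bits, b in bits 8..15, d in bits 0..7 (the descriptor offsets themselves are rounded down to minStorageBufferOffsetAlignment later in ggml_vk_op_f32, so only small remainders need to fit). A sketch of the packing and the inverse a shader-side decode would perform; the bit layout is taken from the assignment in the diff:

// Sketch of the misalign_offsets packing above and its inverse.
#include <cassert>
#include <cstdint>

uint32_t pack_misalign_offsets(uint32_t a_off, uint32_t b_off, uint32_t d_off) {
    assert(a_off < (1u << 16) && b_off < (1u << 8) && d_off < (1u << 8));
    return (a_off << 16) | (b_off << 8) | d_off;
}

void unpack_misalign_offsets(uint32_t packed, uint32_t &a_off, uint32_t &b_off, uint32_t &d_off) {
    a_off = packed >> 16;
    b_off = (packed >> 8) & 0xffu;
    d_off = packed & 0xffu;
}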
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
...
@@ -8298,10 +8520,11 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
}
template<typename PC>
template<typename PC>
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
if (src1 != nullptr) {
if (src1 != nullptr) {
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
...
@@ -8309,6 +8532,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (src2 != nullptr) {
if (src2 != nullptr) {
std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
}
}
if (src3 != nullptr) {
std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
}
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
...
@@ -8335,6 +8561,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
const uint64_t ne2 = ne20 * ne21;
const uint64_t ne2 = ne20 * ne21;
const bool use_src3 = src3 != nullptr;
const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
const uint64_t ne3 = ne30 * ne31;
const uint64_t ned0 = dst->ne[0];
const uint64_t ned0 = dst->ne[0];
const uint64_t ned1 = dst->ne[1];
const uint64_t ned1 = dst->ne[1];
const uint64_t ned2 = dst->ne[2];
const uint64_t ned2 = dst->ne[2];
...
@@ -8365,6 +8598,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
vk_buffer d_X = nullptr;
vk_buffer d_X = nullptr;
size_t x_buf_offset = 0;
size_t x_buf_offset = 0;
...
@@ -8372,10 +8606,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
size_t y_buf_offset = 0;
size_t y_buf_offset = 0;
vk_buffer d_Z = nullptr;
vk_buffer d_Z = nullptr;
size_t z_buf_offset = 0;
size_t z_buf_offset = 0;
vk_buffer d_W = nullptr;
size_t w_buf_offset = 0;
bool src0_uma = false;
bool src0_uma = false;
bool src1_uma = false;
bool src1_uma = false;
bool src2_uma = false;
bool src2_uma = false;
bool src3_uma = false;
if (ctx->device->uma) {
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
...
@@ -8388,6 +8625,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
src2_uma = d_Z != nullptr;
src2_uma = d_Z != nullptr;
}
}
if (use_src3) {
ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
src3_uma = d_W != nullptr;
}
}
}
vk_buffer d_D = dst_buf_ctx->dev_buffer;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
...
@@ -8409,11 +8650,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
GGML_ASSERT(d_Z != nullptr);
}
}
if (use_src3 && !src3_uma) {
d_W = src3_buf_ctx->dev_buffer;
w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
GGML_ASSERT(d_W != nullptr);
}
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
std::array<uint32_t, 3> elements;
std::array<uint32_t, 3> elements;
...
@@ -8470,6 +8717,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
break;
case GGML_OP_ARGSORT:
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
break;
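The ARGSORT dispatch above now clamps the Y dimension of the grid to the device's maxComputeWorkGroupCount[1]; this is the host side of the "argsort with a large number of rows" fix carried in this commit, and it implies the shader loops when there are more rows than workgroups (rows y, y + gridDim.y, y + 2*gridDim.y, ... -- an assumption based on the clamp, not shown in this hunk). Host-side sketch of the clamp:

// Host-side sketch of the clamped ARGSORT grid size.
#include <algorithm>
#include <cstdint>

uint32_t argsort_workgroups_y(uint64_t nrows, uint32_t max_workgroup_count_y) {
    return static_cast<uint32_t>(std::min<uint64_t>(nrows, max_workgroup_count_y));
}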
case GGML_OP_IM2COL:
case GGML_OP_IM2COL:
{
{
...
@@ -8614,12 +8862,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
break;
}
}
uint64_t x_sz, y_sz, z_sz, d_sz;
uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;
if (op_supports_incontiguous) {
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
if (x_buf_offset + x_sz >= d_X->size) {
...
@@ -8631,6 +8880,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
}
if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
}
if (d_buf_offset + d_sz >= d_D->size) {
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
}
...
@@ -8638,6 +8890,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}
}
...
@@ -8679,14 +8932,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z;
vk_subbuffer subbuf_z, subbuf_w;
if (use_src2) {
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
} else {
subbuf_z = { d_X, 0, x_sz };
subbuf_z = { d_X, 0, x_sz };
}
}
if (use_src3) {
subbuf_w = { d_W, w_buf_offset, w_sz };
} else {
subbuf_w = { d_X, 0, x_sz };
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
// buffer device address path doesn't use dst buffer
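The ROPE/ROPE_BACK branch above extends the existing src2 trick to the new optional src3 binding: when no row-index tensor is fused in, subbuf_w is pointed back at the src0 buffer so the descriptor set is always complete. A minimal standalone sketch of that fallback-binding pattern, using hypothetical stand-in types rather than the real vk_subbuffer/ggml structs:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the backend's buffer/subrange types.
struct buffer   { const char * name; uint64_t size; };
struct subrange { const buffer * buf; uint64_t offset; uint64_t len; };

// Bind the optional source when present, otherwise re-bind the src0 buffer at
// offset 0 so the shader always sees a valid (if unused) descriptor.
static subrange bind_optional(const buffer * opt, uint64_t offset, uint64_t len,
                              const buffer * fallback) {
    if (opt != nullptr) {
        return { opt, offset, len };
    }
    return { fallback, 0, fallback->size };
}

int main() {
    buffer d_X{"src0", 4096};
    buffer d_W{"row_indices", 256};

    subrange fused   = bind_optional(&d_W, 0, d_W.size, &d_X);
    subrange unfused = bind_optional(nullptr, 0, 0, &d_X);

    std::printf("fused rope binds %s, plain rope binds %s as a dummy\n",
                fused.buf->name, unfused.buf->name);
    return 0;
}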
...
@@ -8702,6 +8960,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
} else if (use_src3) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src2) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src1) {
...
@@ -8716,7 +8976,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
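From this hunk onward, every ggml_vk_op_f32 call site gains an explicit nullptr for the new fourth source; only the fused rope path further below passes a real tensor there. A standalone sketch of the call shape, with hypothetical names standing in for the real ggml/Vulkan types:

#include <array>
#include <cstdio>

struct tensor { const char * name; };

// Hypothetical stand-in for ggml_vk_op_f32: up to four optional sources plus dst.
template <typename PC>
static void op_f32(const tensor * src0, const tensor * src1, const tensor * src2,
                   const tensor * src3, const tensor * dst, const PC & pc) {
    (void)pc;
    int bound = 0;
    for (const tensor * s : std::array<const tensor *, 4>{src0, src1, src2, src3}) {
        if (s != nullptr) {
            ++bound;  // each present source becomes one descriptor binding
        }
    }
    std::printf("%s: %d source binding(s)\n", dst->name, bound);
}

int main() {
    tensor a{"weights"}, rows{"row_ids"}, out{"get_rows"};
    struct { unsigned n; } pc{16};
    // Existing call sites simply grow an explicit nullptr in the new src3 slot.
    op_f32(&a, &rows, nullptr, nullptr, &out, pc);
    return 0;
}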
...
@@ -8736,7 +8996,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -8861,7 +9121,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -8876,7 +9136,7 @@ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -8891,7 +9151,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -8906,7 +9166,7 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -8921,7 +9181,7 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t src2_type_size = ggml_type_size(src2->type);
ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
(uint32_t)dst->ne[0],
(uint32_t)dst->ne[1],
(uint32_t)src0->nb[1] / src0_type_size,
...
@@ -9154,7 +9414,7 @@ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx,
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
...
@@ -9272,7 +9532,7 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9282,7 +9542,7 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -9306,7 +9566,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
}
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
...
@@ -9320,23 +9580,23 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
}
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9344,12 +9604,12 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
}
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
}
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9364,17 +9624,17 @@ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, cons
memcpy(&p.param1, &s01_packed, sizeof(float));
memcpy(&p.param2, &s23_packed, sizeof(float));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
}
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9390,7 +9650,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
}
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
}
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9405,7 +9665,7 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
return;
}
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -9416,13 +9676,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9433,7 +9693,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
const float eps = float_op_params[1];
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
...
@@ -9456,7 +9716,7 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
...
@@ -9473,16 +9733,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9505,7 +9765,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
{
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0],
...
@@ -9518,7 +9778,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
int32_t * op_params = (int32_t *)dst->op_params;
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9543,7 +9803,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
...
@@ -9559,15 +9819,17 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
cgraph->nodes[node_idx + 5];
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
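The weights/ids selection above now depends on which of the three fused top-k MoE layouts was matched rather than a single with_norm flag. A small standalone sketch of the node-offset selection, mirroring the ternaries in the diff (enum names taken from the patch, offsets as shown there):

#include <cstdio>

// The three fused top-k MoE layouts the backend recognizes.
enum topk_moe_mode { TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_LATE_SOFTMAX };

// Node offsets, relative to the head of the fused pattern, where the weights
// and ids tensors live for each variant.
static int weights_offset(topk_moe_mode mode) {
    return mode == TOPK_MOE_EARLY_SOFTMAX_NORM ? 9 :
           mode == TOPK_MOE_EARLY_SOFTMAX      ? 4 : 5;
}

static int ids_offset(topk_moe_mode mode) {
    return mode == TOPK_MOE_LATE_SOFTMAX ? 1 : 3;
}

int main() {
    std::printf("late softmax: weights at +%d, ids at +%d\n",
                weights_offset(TOPK_MOE_LATE_SOFTMAX), ids_offset(TOPK_MOE_LATE_SOFTMAX));
    return 0;
}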
...
@@ -9626,9 +9888,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
GGML_ASSERT(n_expert_used <= n_experts);
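For the normalized variant, the clamp bounds are lifted from the fused GGML_OP_CLAMP node at node_idx + 7 and passed to the shader as push constants. A self-contained sketch of reading float op-params out of a raw parameter block, roughly what ggml_get_op_params_f32 does (simplified assumption, not the exact upstream implementation):

#include <cstdint>
#include <cstring>
#include <cstdio>

struct node {
    int32_t op_params[16];  // ggml keeps op params as a small raw buffer
};

// Read the i-th float that was memcpy'd into the params block.
static float get_op_params_f32(const node & n, int i) {
    float v;
    std::memcpy(&v, &n.op_params[i], sizeof(float));
    return v;
}

int main() {
    node clamp{};
    const float lo = 0.0f, hi = 10.0f;
    std::memcpy(&clamp.op_params[0], &lo, sizeof(float));
    std::memcpy(&clamp.op_params[1], &hi, sizeof(float));

    // Same pattern as pc.clamp_min / pc.clamp_max in the hunk above.
    std::printf("clamp_min=%g clamp_max=%g\n",
                get_op_params_f32(clamp, 0), get_op_params_f32(clamp, 1));
    return 0;
}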
...
@@ -9643,7 +9910,12 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
}, pc, elements);
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) {
ggml_tensor * dst = cgraph->nodes[node_idx];
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = nullptr;
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
...
@@ -9667,11 +9939,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
uint32_t set_rows_stride = 0;
// Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
// and overrides the dst and sets src3=row_indices
if (ctx->num_additional_fused_ops > 0) {
set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
src3 = cgraph->nodes[node_idx + 2]->src[1];
dst = cgraph->nodes[node_idx + 2];
}
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
{ sections[0], sections[1], sections[2], sections[3] }, backprop
{ sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
}, dryrun);
}
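With the fused rope + view + set_rows path above, dst is overridden to the set_rows node, src3 carries its row indices, and set_rows_stride (nb[1] of that node divided by the element size, so a stride in elements) tells the shader how far apart destination rows are, so rotated rows are scattered rather than written contiguously. A standalone sketch of that scatter addressing:

#include <cstdio>
#include <vector>

// Scatter src rows into dst at positions given by row_indices, destination rows
// being set_rows_stride elements apart -- the addressing the fused shader performs.
static void scatter_rows(const std::vector<float> & src, int n_rows, int row_len,
                         const std::vector<int> & row_indices, int set_rows_stride,
                         std::vector<float> & dst) {
    for (int r = 0; r < n_rows; ++r) {
        const float * in  = &src[(size_t)r * row_len];
        float       * out = &dst[(size_t)row_indices[r] * set_rows_stride];
        for (int c = 0; c < row_len; ++c) {
            out[c] = in[c];
        }
    }
}

int main() {
    const int n_rows = 2, row_len = 4, stride = 8;
    std::vector<float> src = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int> idx = {3, 0};              // cache slots chosen by set_rows
    std::vector<float> dst(6 * stride, 0.0f);
    scatter_rows(src, n_rows, row_len, idx, stride, dst);
    std::printf("row 3 starts with %g, row 0 starts with %g\n", dst[3 * stride], dst[0]);
    return 0;
}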
...
@@ -9679,35 +9960,37 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
nrows,
op_params[0],
}, dryrun);
}
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
}
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
p.weight = 1.0f / (float)src0->ne[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9740,7 +10023,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
...
@@ -9813,7 +10096,7 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9821,7 +10104,7 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
const uint32_t max_period = dst->op_params[1];
const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
nb1, dim, max_period,
}, dryrun);
}
...
@@ -9854,7 +10137,7 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.s0 = static_cast<uint32_t>(s0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
}
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
...
@@ -9877,7 +10160,7 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
const uint32_t parallel_elements = N * OC * OH * OW;
ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
IW, IH, OW, OH, OC,
parallel_elements,
op,
...
@@ -9931,7 +10214,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
...
@@ -9980,7 +10263,7 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context
GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne12);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
...
@@ -10004,12 +10287,12 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(src0->ne[3] == p.channels);
GGML_ASSERT(src1->ne[3] == p.batches);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
}
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
}
#ifdef GGML_VULKAN_RUN_TESTS
...
@@ -11135,7 +11418,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
...
@@ -11209,9 +11491,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// nodes require synchronization.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
need_sync = true;
break;
}
// If the node actually writes to memory, then check if it needs to sync
if (ctx->fused_ops_write_mask & (1 << i)) {
if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
need_sync = true;
break;
}
}
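Both synchronization loops in this function now consult ctx->fused_ops_write_mask, so fused intermediates that never reach memory neither force a barrier nor get recorded as unsynced writes. A small standalone sketch of the bit-per-fused-op gating:

#include <cstdint>
#include <cstdio>
#include <vector>

// Only the fused nodes whose bit is set in write_mask are treated as real writers.
static std::vector<int> writers(uint32_t write_mask, int num_fused) {
    std::vector<int> out;
    for (int i = 0; i < num_fused; ++i) {
        if (write_mask & (1u << i)) {
            out.push_back(i);
        }
    }
    return out;
}

int main() {
    // e.g. a 3-op fusion where only the last op actually writes its result to memory
    const uint32_t mask = 1u << 2;
    for (int i : writers(mask, 3)) {
        std::printf("fused op %d is a writer\n", i);
    }
    return 0;
}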
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
...
@@ -11223,7 +11508,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
#define ENABLE_SYNC_LOGGING 0
if (need_sync) {
#if ENABLE_SYNC_LOGGING
std::cerr << "sync" << std::endl;
#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
...
@@ -11232,7 +11523,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
         const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
         // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
-        ctx->unsynced_nodes_written.push_back(cur_node);
+        if (ctx->fused_ops_write_mask & (1 << i)) {
+            ctx->unsynced_nodes_written.push_back(cur_node);
+        }
        for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
            if (!cur_node->src[j]) {
                continue;
...
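The two hunks above replace the old "every fused node is a write" assumption with a per-node bitmask. A minimal stand-alone sketch of the idea in plain C++ (not the ggml-vulkan code; the 10-node pattern size and the bit choices are made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // One bit per node of the fused group; bit i set => node i writes to memory.
        uint32_t fused_ops_write_mask = 0;

        // Hypothetical 10-node top-k MoE pattern where node 3 (the VIEW of the ARGSORT)
        // writes, plus the final node, which always produces the fused result.
        int num_additional_fused_ops = 9;
        fused_ops_write_mask |= 1u << 3;
        fused_ops_write_mask |= 1u << num_additional_fused_ops;

        for (int i = 0; i < num_additional_fused_ops + 1; ++i) {
            bool writes = (fused_ops_write_mask & (1u << i)) != 0;
            // Only writer nodes need to be checked against (and recorded in) the
            // unsynced-node lists; read-only interior nodes of the pattern are skipped.
            printf("fused node %2d: %s\n", i, writes ? "writes -> hazard check" : "read-only -> skip");
        }
        return 0;
    }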
@@ -11241,6 +11534,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
     }
+#if ENABLE_SYNC_LOGGING
+    if (!dryrun) {
+        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+            auto *n = cgraph->nodes[node_idx + i];
+            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+            if (n->op == GGML_OP_GLU) {
+                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+            }
+            std::cerr << std::endl;
+        }
+    }
+#endif
     switch (node->op) {
     case GGML_OP_REPEAT:
...
@@ -11411,15 +11716,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false, dryrun);

         break;
     case GGML_OP_ROPE_BACK:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true, dryrun);

         break;
     case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        }

         break;
     case GGML_OP_SUM:
...
@@ -12217,30 +12526,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
 }

 static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
-                                      int node_idx, bool with_norm) {
-    if (with_norm) {
-        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
-                return false;
-            }
-        }
-    } else {
-        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
-                return false;
-            }
-        }
-    }
-
-    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
-    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+                                      int node_idx, topk_moe_mode mode) {
+    const ggml_tensor * softmax;
+    const ggml_tensor * weights;
+
+    switch (mode) {
+    case TOPK_MOE_EARLY_SOFTMAX_NORM:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 9];
+        break;
+    case TOPK_MOE_EARLY_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 4];
+        break;
+    case TOPK_MOE_LATE_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 4];
+        weights = cgraph->nodes[node_idx + 5];
+        break;
+    default:
+        return false;
+    }

     const float * op_params = (const float *)softmax->op_params;
...
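For orientation, the three topk_moe_mode cases correspond to different gating orders in the model graph: softmax before top-k selection (optionally renormalizing the kept weights, the *_NORM variant) versus gpt-oss-style softmax after top-k. A rough, self-contained C++ sketch of that arithmetic (illustrative only, not the fused shader; the logits are made up):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    // Softmax over a small vector of router logits.
    static std::vector<float> softmax(std::vector<float> x) {
        float m = *std::max_element(x.begin(), x.end());
        float s = 0.0f;
        for (float &v : x) { v = std::exp(v - m); s += v; }
        for (float &v : x) { v /= s; }
        return x;
    }

    // Indices of the k largest entries.
    static std::vector<int> topk(const std::vector<float> &x, int k) {
        std::vector<int> idx(x.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return x[a] > x[b]; });
        idx.resize(k);
        return idx;
    }

    int main() {
        const std::vector<float> logits = { 0.1f, 2.0f, -1.0f, 0.7f };
        const int k = 2;

        // "Early softmax": softmax all logits, then keep the top-k probabilities.
        // The *_NORM variant additionally renormalizes the kept weights to sum to 1.
        std::vector<float> probs = softmax(logits);
        std::vector<int> early = topk(probs, k);
        float kept = probs[early[0]] + probs[early[1]];
        printf("early:      expert %d w=%.3f, expert %d w=%.3f\n",
               early[0], probs[early[0]], early[1], probs[early[1]]);
        printf("early+norm: expert %d w=%.3f, expert %d w=%.3f\n",
               early[0], probs[early[0]] / kept, early[1], probs[early[1]] / kept);

        // "Late softmax" (gpt-oss style): pick the top-k logits first, then softmax
        // only over the selected experts.
        std::vector<int> late = topk(logits, k);
        std::vector<float> w = softmax({ logits[late[0]], logits[late[1]] });
        printf("late:       expert %d w=%.3f, expert %d w=%.3f\n",
               late[0], w[0], late[1], w[1]);
        return 0;
    }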
@@ -12266,64 +12572,45 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         return false;
     }

-    // Check that the nodes don't have any unexpected uses
-    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
-    const ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
-    const ggml_tensor * view     = cgraph->nodes[node_idx + 3];
-    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
-    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
-    const ggml_tensor * div      = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
-    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-    // softmax is used by reshape and argsort
-    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
-        reshape1->src[0] != softmax ||
-        argsort->src[0] != softmax) {
-        return false;
-    }
-    // reshape is used by get_rows
-    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
-        get_rows->src[0] != reshape1) {
-        return false;
-    }
-    // argsort is used by view
-    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
-        view->src[0] != argsort) {
-        return false;
-    }
-    // view is written (via argsort), we can skip checking it
-    if (with_norm) {
-        // get_rows is used by reshape
-        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
-            reshape5->src[0] != get_rows) {
-            return false;
-        }
-        // reshape is used by sum_rows and div
-        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
-            sum_rows->src[0] != reshape5 ||
-            div->src[0] != reshape5) {
-            return false;
-        }
-        // sum_rows is used by div
-        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
-            div->src[1] != sum_rows) {
-            return false;
-        }
-        // div/reshape are written
-        if (reshape8->src[0] != div) {
-            return false;
-        }
-    }
-
     if (!ctx->device->subgroup_arithmetic ||
         !ctx->device->subgroup_shuffle ||
         !ctx->device->subgroup_require_full_support ||
         ctx->device->disable_fusion) {
         return false;
     }
+
+    return true;
+}
+
+static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                           int node_idx) {
+    GGML_UNUSED(ctx);
+    const ggml_tensor *rope     = cgraph->nodes[node_idx + 0];
+    const ggml_tensor *view     = cgraph->nodes[node_idx + 1];
+    const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
+
+    // ne3 not tested
+    if (rope->src[0]->ne[3] != 1) {
+        return false;
+    }
+
+    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    if (set_rows->src[1]->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    // The view should flatten two dims of rope into one dim
+    if (!ggml_is_contiguous(view) ||
+        view->ne[0] != rope->ne[0] * rope->ne[1]) {
+        return false;
+    }
+
+    // Only norm/neox shaders have the fusion code
+    const int mode = ((const int32_t *) rope->op_params)[2];
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+        return false;
+    }
...
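The new ggml_vk_can_fuse_rope_set_rows gate is mostly shape and type checks. A tiny illustrative check of its core condition, that the VIEW flattens the first two rope dims into one row so SET_ROWS can scatter whole rows (toy types, not the ggml tensor API):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for the shape array (ne[0..3]) a ggml tensor carries.
    struct toy_tensor { int64_t ne[4]; };

    // The view must fold head_dim x n_heads of the rope result into a single row.
    static bool view_flattens_rope(const toy_tensor &rope, const toy_tensor &view) {
        return view.ne[0] == rope.ne[0] * rope.ne[1];
    }

    int main() {
        toy_tensor rope = {{128, 32, 512, 1}};      // head_dim, n_heads, n_tokens, 1
        toy_tensor view = {{128 * 32, 512, 1, 1}};  // one flattened row per token
        printf("rope+view+set_rows fusable: %s\n", view_flattens_rope(rope, view) ? "yes" : "no");
        return 0;
    }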
@@ -12405,10 +12692,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+                ctx->num_additional_fused_ops = 2;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
...
@@ -12506,12 +12805,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+                ctx->num_additional_fused_ops = 2;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 3;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 3;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 1;
             }
         }

+        ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
...
@@ -12557,6 +12875,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         }
         i += ctx->num_additional_fused_ops;
         ctx->num_additional_fused_ops = 0;
+        ctx->fused_ops_write_mask = 0;
     }

     if (vk_perf_logger_enabled) {
...
@@ -12642,25 +12961,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;

-        // Avoid reordering topk_moe_norm
-        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
-            bool is_topk_moe_norm = true;
-            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
-                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
-                    is_topk_moe_norm = false;
-                }
-            }
-            if (is_topk_moe_norm) {
-                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
-                    new_order.push_back(graph->nodes[first_unused + j]);
-                    used[first_unused + j] = true;
-                }
-                while (first_unused < graph->n_nodes && used[first_unused]) {
-                    first_unused++;
-                }
-                continue;
-            }
-        }
+        // Check for fusion patterns and avoid reordering them
+        auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+            if (start + (int)pattern.size() <= graph->n_nodes) {
+                bool is_pattern = true;
+                for (size_t j = 0; j < pattern.size(); ++j) {
+                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+                        is_pattern = false;
+                    }
+                }
+                return is_pattern;
+            }
+            return false;
+        };
+
+        auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+            if (match_pattern(pattern, first_unused)) {
+                for (size_t j = 0; j < pattern.size(); ++j) {
+                    new_order.push_back(graph->nodes[first_unused + j]);
+                    used[first_unused + j] = true;
+                }
+                while (first_unused < graph->n_nodes && used[first_unused]) {
+                    first_unused++;
+                }
+                return true;
+            }
+            return false;
+        };
+
+        if (keep_pattern(topk_moe_early_softmax_norm)) {
+            continue;
+        }
+        if (keep_pattern(topk_moe_early_softmax)) {
+            continue;
+        }
+        if (keep_pattern(topk_moe_late_softmax)) {
+            continue;
+        }

         // First, grab the next unused node.
         current_set.push_back(first_unused);
...
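The graph-optimize change generalizes the old single topk_moe_norm special case into a reusable pattern matcher. A minimal sketch of that kind of matcher over a flat op list (toy enum and data, not ggml's graph types):

    #include <cstdio>
    #include <initializer_list>
    #include <vector>

    enum toy_op { OP_SOFT_MAX, OP_RESHAPE, OP_ARGSORT, OP_VIEW, OP_GET_ROWS, OP_MUL_MAT };

    // True if the fixed op sequence `pattern` starts at index `start` of `nodes`.
    static bool match_pattern(const std::vector<toy_op> &nodes,
                              std::initializer_list<toy_op> pattern, size_t start) {
        if (start + pattern.size() > nodes.size()) return false;
        size_t j = 0;
        for (toy_op op : pattern) {
            if (nodes[start + j++] != op) return false;
        }
        return true;
    }

    int main() {
        std::vector<toy_op> nodes = { OP_MUL_MAT, OP_SOFT_MAX, OP_RESHAPE, OP_ARGSORT, OP_VIEW, OP_GET_ROWS };
        bool hit = match_pattern(nodes, { OP_SOFT_MAX, OP_RESHAPE, OP_ARGSORT, OP_VIEW, OP_GET_ROWS }, 1);
        // A matched sequence would be emitted as one block and never interleaved with other work.
        printf("topk-moe-like pattern at node 1: %s\n", hit ? "yes" : "no");
        return 0;
    }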
@@ -12678,6 +13016,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                 if (is_empty(graph->nodes[j])) {
                     continue;
                 }
+                // Don't pull forward nodes from fusion patterns
+                if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                    match_pattern(topk_moe_early_softmax, j) ||
+                    match_pattern(topk_moe_late_softmax, j)) {
+                    continue;
+                }
                 bool ok = true;
                 for (int c = first_unused; c < j; ++c) {
                     if (!used[c] &&
...
@@ -12689,6 +13033,32 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                 }
                 if (ok) {
                     current_set.push_back(j);
+                    // Look for ROPE + VIEW + SET_ROWS and make them consecutive
+                    if (graph->nodes[j]->op == GGML_OP_ROPE) {
+                        int view_idx = -1;
+                        int set_rows_idx = -1;
+                        for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
+                            if (view_idx == -1 &&
+                                graph->nodes[k]->op == GGML_OP_VIEW &&
+                                graph->nodes[k]->src[0] == graph->nodes[j]) {
+                                view_idx = k;
+                                continue;
+                            }
+                            if (view_idx != -1 &&
+                                set_rows_idx == -1 &&
+                                graph->nodes[k]->op == GGML_OP_SET_ROWS &&
+                                graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
+                                set_rows_idx = k;
+                                break;
+                            }
+                        }
+                        if (set_rows_idx != -1) {
+                            current_set.push_back(view_idx);
+                            current_set.push_back(set_rows_idx);
+                            used[view_idx] = true;
+                            used[set_rows_idx] = true;
+                        }
+                    }
                 }
             }

         // Second pass grabs view nodes.
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
...
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};

 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;
...
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }

-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;

     const uint row_offset = row * p.ncols;
...
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }
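The argsort shader change above removes the one-workgroup-per-row assumption: nrows is passed as a push constant and each workgroup strides over rows, so a single dispatch can cover more rows than the per-dimension workgroup limit. A plain C++ sketch of the same loop shape (the limits below are illustrative, not Vulkan's actual maximums):

    #include <cstdio>

    int main() {
        const unsigned nrows = 100000;         // more rows than we can launch workgroups for
        const unsigned workgroups_y = 32768;   // hypothetical per-dimension dispatch limit
        unsigned long covered = 0;
        for (unsigned wg = 0; wg < workgroups_y; ++wg) {            // gl_WorkGroupID.y
            for (unsigned row = wg; row < nrows; row += workgroups_y) {
                // argsort(row) runs here; each row is still owned by exactly one workgroup
                ++covered;
            }
        }
        printf("covered %lu of %u rows with %u workgroups\n", covered, nrows, workgroups_y);
        return 0;
    }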
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
...
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 #if defined(DATA_A_MXFP4)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
 }
 vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
     vec2 v0 = dequantize(ib, iqs, a_offset);
...
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
     const uint scales = data_a[a_offset + ib].scales[scalesi];
-    const vec2 d = vec2(data_a[a_offset + ib].d);
-    return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+    const vec2 dm = vec2(data_a[a_offset + ib].dm);
+    return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
 }
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(1, 0);
...
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint is = 2 * n + b;                 // 0..7
     const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
...
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint8_t hm = uint8_t(1 << (iqs / 16));
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
...
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
     decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
-    const f16vec2 d = bl.block.d;
+    const f16vec2 dm = bl.block.dm;
     const uint idx = coordInBlock[1];
     const uint scalesi = (idx & 0xF0) >> 4;  // 0..15
...
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
     qs = unpack8(qs)[idx & 1];
     const uint scales = bl.block.scales[scalesi];
-    float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+    float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
     return ret;
 }
...
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
     uint32_t qs = bl.block.qs[iqs];
     qs >>= shift;
     qs &= 0xF;
-    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
     return ret;
 }
 #endif
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
...
@@ -26,7 +26,7 @@ void main() {
     const float d = e8m0_to_fp32(data_a[ib].e);

     [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l + 0]  = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
-        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+        data_b[b_idx + l + 0]  = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+        data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
     }
 }
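The recurring * 0.5 factor in the MXFP4 shaders makes sense if the shared lookup table stores the e2m1 code points doubled so they fit in small integers; dequantization then halves the looked-up value. A hedged, self-contained C++ sketch of that assumption (the table below is illustrative, not copied from ggml):

    #include <cstdint>
    #include <cstdio>

    // Assumed table: MXFP4 (e2m1) values 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negatives,
    // stored doubled so they are exact small integers.
    static const int8_t kvalues[16] = { 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 };

    int main() {
        const float d = 2.0f;          // per-block scale (an e8m0 exponent in the real format)
        const uint8_t packed = 0x5A;   // two 4-bit codes packed into one byte
        float lo = d * 0.5f * kvalues[packed & 0xF];   // halve to undo the doubled table entries
        float hi = d * 0.5f * kvalues[packed >> 4];
        printf("lo=%g hi=%g\n", lo, hi);
        return 0;
    }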
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
...
@@ -24,8 +24,8 @@ void main() {
     const uint ql_idx = 32 * ip + il;
     const uint8_t qs = data_a[i].qs[32 * ip + il];

-    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
     data_b[y_idx + 0]  = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
     data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
     data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
...
...
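The .d -> .dm renames in these K-quant shaders track a block field that packs the superblock scale and min together; the dequantization itself is unchanged. A tiny C++ restatement of the Q2_K formula used above (made-up values):

    #include <cstdio>

    int main() {
        // One superblock carries a pair of half-precision factors: { dall, dmin }.
        float dm[2] = { 0.125f, 0.03f };
        unsigned scale = 0xB7;   // low nibble scales the quant, high nibble scales the min
        unsigned q = 2;          // a 2-bit quantized value
        // x = dall * (scale & 0xF) * q  -  dmin * (scale >> 4)
        float x = dm[0] * float(scale & 0xF) * float(q) - dm[1] * float(scale >> 4);
        printf("dequantized: %g\n", x);
        return 0;
    }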
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
...
@@ -20,8 +20,8 @@ void main() {
     const uint is = 2 * il;
     const uint n = 4;

-    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);

     const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
     const uint qs_idx = 32*il + n * ir;
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
...
@@ -19,8 +19,8 @@ void main() {
     const uint ir = tid % 16;
     const uint is = 2 * il;

-    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);

     const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
     const uint qs_idx = 32*il + 2 * ir;
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
...
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
         const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
...
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
                           fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
                           fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
             }
-            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+            temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
         }
     }
 }
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
...
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);

         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
...
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
                 fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
                 fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
                 fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
...
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);

         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
...
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
                 fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
                 fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
                    (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
...
...
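In the matrix-vector kernels the same pair now arrives as dm.x / dm.y and feeds one fused multiply-add chain per partial sum. A small C++ sketch of that accumulation step (toy numbers only):

    #include <cmath>
    #include <cstdio>

    int main() {
        float dm_x = 0.25f, dm_y = 0.01f;  // { dall, dmin } for one superblock
        float sum1 = 12.0f;                // scale-weighted sum of quants
        float sum2 = 6.0f;                 // weight applied to the per-group mins
        float temp = 0.0f;
        // temp = dall * sum1 - dmin * sum2, expressed as nested fused multiply-adds
        temp = std::fma(dm_x, sum1, std::fma(-dm_y, sum2, temp));
        printf("partial dot product: %g\n", temp);
        return 0;
    }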
xuxzh1 mentioned in commit 0cf7794b16fab8d4561bc5f6379f6d48bd59e101 · Jan 09, 2026