Unverified commit 0cf7794b authored by Daniel Hiltgen, committed by GitHub

ggml update to b7108 (#12992)

* Revert "vulkan: temporary carry of vulkan fixes (#12971)"

This reverts commit 3a9e8e9f.

* ggml update to b7087

* fix argsort on metal

* update to b7108

* fix bakllava regression

This model lacks the metadata for the projector type.

* update to b7209

* fix TopK perf

* only build arm code on arm
parent 854d40ed
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 30 Nov 2025 11:05:56 -0800
Subject: [PATCH] ggml: Export GPU UUIDs
This enables matching up devices and information reported by the backend
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
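
Note (illustration, not part of the patch): with the new `id` field on ggml_backend_dev_props, a frontend can correlate ggml devices with what nvidia-smi or NVML report. A minimal sketch using the existing device-enumeration API from ggml-backend.h; the `id` field is the one this patch adds:

#include <stdio.h>
#include "ggml-backend.h"

int main(void) {
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        // props.id is the UUID exported by this patch (e.g. "GPU-..."),
        // which should line up with `nvidia-smi -L` / nvmlDeviceGetUUID().
        printf("%s: %s\n", props.name, props.id ? props.id : "(none)");
    }
    return 0;
}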
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-cuda/ggml-cuda.cu | 67 +++++++++++++++++++++++++++---
...@@ -24,10 +22,10 @@ index c54ff98bf..229bf387b 100644
size_t memory_total;
// device type
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8f3b1c173..e803f4af6 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -185,6 +185,51 @@ static int ggml_cuda_parse_id(char devName[]) {
}
#endif // defined(GGML_USE_HIP)
...@@ -79,7 +77,7 @@ index aefc6935e..cc201afff 100644
static ggml_cuda_device_info ggml_cuda_init() {
ggml_cuda_device_info info = {};
@@ -251,22 +296,24 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc += prop.minor * 0x10;
}
}
...@@ -110,7 +108,7 @@ index aefc6935e..cc201afff 100644
std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name });
@@ -4048,6 +4095,7 @@ struct ggml_backend_cuda_device_context {
std::string name;
std::string description;
std::string pci_bus_id;
...@@ -118,9 +116,9 @@ index aefc6935e..cc201afff 100644
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -4136,6 +4184,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
return ctx->description.c_str();
}
#endif // defined(__linux__)
+static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
...@@ -130,7 +128,7 @@ index aefc6935e..cc201afff 100644
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
@@ -4176,6 +4229,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
...@@ -138,7 +136,7 @@ index aefc6935e..cc201afff 100644
props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -4767,6 +4821,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
...@@ -147,10 +145,10 @@ index aefc6935e..cc201afff 100644
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index f2b7fe692..8fc1c2fb5 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -547,6 +547,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
...
...@@ -10,10 +10,10 @@ Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2 files changed, 13 insertions(+)
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index dfad9cd79..9858de630 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -87,6 +87,16 @@ enum mtmd_slice_tmpl {
MTMD_SLICE_TMPL_IDEFICS3,
};
...@@ -31,7 +31,7 @@ index 4d487581a..35a0d25ed 100644
return "<__media__>";
}
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 015119be8..8d3fa5d34 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk;
...
...@@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 5be08d6f4..7a0df30c3 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2463,7 +2463,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
// all our threads onto the first 4 cores which results in terrible performance with
// n_threads > 4
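
Note (context, not part of the patch): judging by the subject line and the one-line change, this widens a compiler guard so the anti-throttling workaround also applies to GCC/MinGW builds on Windows. The Win32 mechanism involved is thread power throttling; a hedged sketch of that call:

#include <windows.h>

// Sketch: opt the current thread out of power throttling (EcoQoS), so the
// scheduler is less inclined to park it on efficiency cores.
static BOOL thread_disable_power_throttling(void) {
    THREAD_POWER_THROTTLING_STATE state = {0};
    state.Version     = THREAD_POWER_THROTTLING_CURRENT_VERSION;
    state.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED;
    state.StateMask   = 0; // bit cleared while in ControlMask = throttling off
    return SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling,
                                &state, sizeof(state));
}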
...
...@@ -58,7 +58,7 @@ index 6792ba986..0f5b03cef 100644
// (optional) event synchronization
// record an event on this stream
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index ff41c7712..f511e8d76 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -348,14 +348,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba
...@@ -97,7 +97,7 @@ index cb2b99562..41eef3b5f 100644
for (int b = 0; b < src_backend_id; b++) {
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
SET_CAUSE(tensor, "1.off");
@@ -1556,7 +1558,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
if (!sched->callback_eval) {
...@@ -106,7 +106,7 @@ index cb2b99562..41eef3b5f 100644
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
@@ -1578,7 +1580,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
...@@ -115,7 +115,7 @@ index cb2b99562..41eef3b5f 100644
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
@@ -1657,6 +1659,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
sched->op_offload = op_offload;
...@@ -123,7 +123,7 @@ index cb2b99562..41eef3b5f 100644
ggml_backend_sched_reset(sched);
@@ -1688,6 +1691,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
free(sched);
}
...@@ -178,10 +178,10 @@ index 3191faaa4..32f14c811 100644
static const struct ggml_backend_i ggml_backend_cpu_i = {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index e803f4af6..78fb2d8b3 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2885,7 +2885,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
...@@ -190,7 +190,7 @@ index cc201afff..02d413467 100644
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
@@ -2918,24 +2918,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
#endif
}
...@@ -241,7 +241,7 @@ index cc201afff..02d413467 100644
}
if (!use_cuda_graph) {
@@ -3679,7 +3689,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}
...@@ -250,7 +250,7 @@ index cc201afff..02d413467 100644
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
@@ -3717,7 +3727,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
...@@ -260,10 +260,10 @@ index cc201afff..02d413467 100644
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 8fc1c2fb5..ba95b4acc 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -419,10 +419,12 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml
GGML_UNUSED(dst);
}
...@@ -278,10 +278,10 @@ index f2ff9f322..05ff6a5a6 100644
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 83cdec29e..a36c6560c 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -13103,7 +13103,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
return num_adds;
}
...@@ -290,7 +290,7 @@ index 216dc167c..3a6bbe564 100644
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -13320,6 +13320,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
return GGML_STATUS_SUCCESS;
UNUSED(backend);
...
...@@ -12,8 +12,8 @@ must be recreated with no-alloc set to false before loading data.
ggml/src/ggml-backend-impl.h | 16 +++
ggml/src/ggml-backend.cpp | 72 ++++++++++-
ggml/src/ggml-cuda/common.cuh | 58 ++++++++-
ggml/src/ggml-cuda/ggml-cuda.cu | 218 ++++++++++++++++++++++++++------
5 files changed, 321 insertions(+), 44 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 2763f2bd6..b3b5b356a 100644
...@@ -75,7 +75,7 @@ index 0f5b03cef..7bdf9d81f 100644
struct ggml_backend {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index f511e8d76..74b7f070c 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -41,6 +41,19 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
...@@ -134,7 +134,7 @@ index 41eef3b5f..c81a2e48a 100644
};
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -1614,6 +1640,17 @@ ggml_backend_sched_t ggml_backend_sched_new(
size_t graph_size,
bool parallel,
bool op_offload) {
...@@ -152,7 +152,7 @@ index 41eef3b5f..c81a2e48a 100644
GGML_ASSERT(n_backends > 0);
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1655,11 +1692,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
}
}
...@@ -167,7 +167,7 @@ index 41eef3b5f..c81a2e48a 100644
ggml_backend_sched_reset(sched);
@@ -1674,6 +1714,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
for (int c = 0; c < sched->n_copies; c++) {
ggml_backend_event_free(sched->events[b][c]);
}
...@@ -178,7 +178,7 @@ index 41eef3b5f..c81a2e48a 100644
}
ggml_gallocr_free(sched->galloc);
ggml_free(sched->ctx);
@@ -1719,6 +1763,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
return false;
}
...@@ -203,7 +203,7 @@ index 41eef3b5f..c81a2e48a 100644
ggml_backend_sched_reset(sched);
return true;
@@ -1824,7 +1886,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched,
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
...@@ -219,10 +219,10 @@ index 41eef3b5f..c81a2e48a 100644
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 611341deb..c3f8ca914 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -37,6 +37,41 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
...@@ -264,7 +264,7 @@ index 41ff89c4d..2931c15ca 100644
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -891,6 +926,9 @@ struct ggml_cuda_pool {
virtual void * alloc(size_t size, size_t * actual_size) = 0;
virtual void free(void * ptr, size_t size) = 0;
...@@ -274,46 +274,48 @@ index 41ff89c4d..2931c15ca 100644
};
template<typename T>
@@ -1179,11 +1217,11 @@ struct ggml_backend_cuda_context {
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
- static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
+ static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no, bool alloc);
ggml_cuda_pool & pool(int device) {
if (pools[device][curr_stream_no] == nullptr) {
- pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
+ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, true);
}
return *pools[device][curr_stream_no];
}
@@ -1191,6 +1229,22 @@ struct ggml_backend_cuda_context {
ggml_cuda_pool & pool() {
return pool(device);
}
+
+ void pool_set_alloc(bool alloc) {
+ GGML_ASSERT(pools[device][curr_stream_no] == nullptr || pools[device][curr_stream_no]->alloc_memory() == alloc);
+
+ if (pools[device][curr_stream_no] == nullptr) {
+ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no, alloc);
+ }
+ }
+
+ size_t pool_get_alloc_size() {
+ if (pools[device][curr_stream_no] == nullptr) {
+ return 0;
+ }
+
+ return pools[device][curr_stream_no]->alloc_size();
+ }
};
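
Note (not the patch's code): the alloc_memory()/alloc_size() virtuals and the pool_set_alloc()/pool_get_alloc_size() helpers above suggest a pool that can run in a measure-only mode, tracking how much VRAM a graph would need without actually allocating it. A hypothetical minimal shape of that idea:

#include <cstddef>

// Sketch: a pool that either hands out real memory or only accounts for it.
struct measuring_pool {
    bool   do_alloc;       // false = reserve pass: bookkeeping only
    size_t allocated = 0;  // what alloc_size() would report

    explicit measuring_pool(bool alloc) : do_alloc(alloc) {}

    void * alloc(size_t size) {
        allocated += size;
        return do_alloc ? ::operator new(size) : nullptr;
    }

    size_t alloc_size() const { return allocated; }
};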
struct ggml_cuda_mm_fusion_args_host {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 78fb2d8b3..fe0da71ca 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -361,6 +361,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
// #define DEBUG_CUDA_MALLOC
...@@ -322,7 +324,7 @@ index 02d413467..f79e5d65c 100644
// buffer pool for cuda (legacy)
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
static const int MAX_BUFFERS = 256;
@@ -373,9 +375,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
size_t pool_size = 0;
...@@ -337,7 +339,7 @@ index 02d413467..f79e5d65c 100644
}
~ggml_cuda_pool_leg() {
@@ -383,7 +388,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
for (int i = 0; i < MAX_BUFFERS; ++i) {
ggml_cuda_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
...@@ -348,7 +350,7 @@ index 02d413467..f79e5d65c 100644
pool_size -= b.size;
}
}
@@ -431,8 +438,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
...@@ -366,7 +368,7 @@ index 02d413467..f79e5d65c 100644
*actual_size = look_ahead_size;
pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
@@ -452,10 +466,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
}
}
GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
...@@ -389,7 +391,7 @@ index 02d413467..f79e5d65c 100644
};
// pool with virtual memory
@@ -467,18 +491,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
CUdeviceptr pool_addr = 0;
size_t pool_used = 0;
size_t pool_size = 0;
...@@ -417,7 +419,7 @@ index 02d413467..f79e5d65c 100644 ...@@ -417,7 +419,7 @@ index 02d413467..f79e5d65c 100644
#if defined(GGML_USE_HIP)
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
@@ -505,35 +535,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
...@@ -493,7 +495,7 @@ index 02d413467..f79e5d65c 100644 ...@@ -493,7 +495,7 @@ index 02d413467..f79e5d65c 100644
// add to the pool // add to the pool
pool_size += reserve_size; pool_size += reserve_size;
@@ -564,16 +608,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { @@ -566,17 +610,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
// all deallocations must be in reverse order of the allocations // all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
} }
...@@ -505,11 +507,14 @@ index 02d413467..f79e5d65c 100644 ...@@ -505,11 +507,14 @@ index 02d413467..f79e5d65c 100644
+ size_t alloc_size() override { + size_t alloc_size() override {
+ return pool_size + last_alloc; + return pool_size + last_alloc;
+ } + }
+
}; };
#endif // defined(GGML_USE_VMM) #endif // defined(GGML_USE_VMM)
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) { std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) { - [[maybe_unused]] int stream_no) {
+ [[maybe_unused]] int stream_no,
+ bool alloc) {
#if defined(GGML_USE_VMM) #if defined(GGML_USE_VMM)
if (ggml_cuda_info().devices[device].vmm) { if (ggml_cuda_info().devices[device].vmm) {
- return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device)); - return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
...@@ -521,7 +526,7 @@ index 02d413467..f79e5d65c 100644
}
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
@@ -760,11 +814,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
}
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
...@@ -543,7 +548,7 @@ index 02d413467..f79e5d65c 100644
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
size_t size = ggml_nbytes(tensor);
int64_t ne0 = tensor->ne[0];
@@ -788,6 +851,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
/* .is_host = */ NULL,
...@@ -551,7 +556,7 @@ index 02d413467..f79e5d65c 100644
};
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
@@ -3258,6 +3322,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
...@@ -559,7 +564,7 @@ index 02d413467..f79e5d65c 100644
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
@@ -3347,6 +3412,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}
...@@ -567,11 +572,10 @@ index 02d413467..f79e5d65c 100644 ...@@ -567,11 +572,10 @@ index 02d413467..f79e5d65c 100644
+ if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
+ continue;
+ }
+
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
@@ -3691,6 +3760,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
...@@ -579,7 +583,7 @@ index 02d413467..f79e5d65c 100644
ggml_cuda_set_device(cuda_ctx->device);
@@ -3766,6 +3836,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
return GGML_STATUS_SUCCESS;
}
...@@ -645,16 +649,16 @@ index 02d413467..f79e5d65c 100644 ...@@ -645,16 +649,16 @@ index 02d413467..f79e5d65c 100644
+ +
+static void ggml_backend_cuda_reset(ggml_backend_t backend) {
+ ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context;
+ ctx->pools[ctx->device][ctx->curr_stream_no] = NULL;
+}
+
static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -4035,6 +4170,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
+ /* .graph_reserve = */ ggml_backend_cuda_graph_reserve,
+ /* .buffer_size = */ ggml_backend_cuda_buffer_size,
+ /* .reset = */ ggml_backend_cuda_reset,
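
Note (usage sketch, not from the patch): with graph_reserve, buffer_size, and reset wired into the backend interface, a caller can preflight a worst-case graph and read back how much memory the backend would need. ggml_backend_sched_reserve() and ggml_backend_sched_get_buffer_size() are existing scheduler APIs; ggml_backend_sched_get_attempted_buffer_size() is the one this patch adds:

// Reserve buffers for a worst-case graph, then query the per-backend size.
static bool preflight(ggml_backend_sched_t sched, ggml_backend_t backend,
                      struct ggml_cgraph * worst_case, size_t * size_out) {
    if (!ggml_backend_sched_reserve(sched, worst_case)) {
        return false; // reservation failed (or, in no-alloc mode, would fail)
    }
    *size_out = ggml_backend_sched_get_buffer_size(sched, backend);
    return true;
}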
...
...@@ -8,12 +8,12 @@ Subject: [PATCH] decode: disable output_all
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index e04f0fc4f..1359c614b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -999,8 +999,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd_inp();
- // when computing embeddings, all tokens are output
- const bool output_all = cparams.embeddings;
...
...@@ -43,7 +43,7 @@ index 7bdf9d81f..21b35ac5c 100644
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 74b7f070c..8d2cc167f 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
...@@ -62,10 +62,10 @@ index c81a2e48a..9b0a9b91f 100644
GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index fe0da71ca..0787e443c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -109,6 +109,11 @@ int ggml_cuda_get_device() {
return id;
}
...@@ -77,7 +77,7 @@ index f79e5d65c..c9333689f 100644
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -4380,7 +4385,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
...@@ -89,7 +89,7 @@ index f79e5d65c..c9333689f 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -4835,6 +4843,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
...@@ -101,7 +101,7 @@ index f79e5d65c..c9333689f 100644
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -4851,6 +4864,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
...@@ -110,7 +110,7 @@ index f79e5d65c..c9333689f 100644
// backend reg
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index b7d6edf7f..b987d7aeb 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -45,6 +45,7 @@
...
...@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 976 insertions(+), 17 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
...@@ -45,7 +45,7 @@ index 69223c488..6510e0cba 100644
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 6d493a4ff..ac8f38464 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -209,6 +209,8 @@ add_library(ggml-base
...@@ -56,12 +56,12 @@ index f9a6587f1..03f359ae9 100644
+ mem_nvml.cpp
gguf.cpp)
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 0787e443c..736d47c07 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -263,6 +263,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
...@@ -78,7 +78,7 @@ index c9333689f..f1a20e7fe 100644
#if defined(GGML_USE_VMM)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
@@ -316,6 +326,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
...@@ -90,7 +90,7 @@ index c9333689f..f1a20e7fe 100644
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
ggml_cuda_parse_uuid(prop, id).c_str());
@@ -4249,6 +4264,11 @@ struct ggml_backend_cuda_device_context {
std::string description;
std::string pci_bus_id;
std::string id;
...@@ -102,7 +102,7 @@ index c9333689f..f1a20e7fe 100644
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -4345,6 +4365,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
...@@ -129,9 +129,9 @@ index c9333689f..f1a20e7fe 100644
+ }
+#endif
CUDA_CHECK(cudaMemGetInfo(free, total));
}
// ref: https://github.com/ggml-org/llama.cpp/pull/17368
@@ -4377,6 +4419,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
return GGML_BACKEND_DEVICE_TYPE_GPU;
}
...@@ -139,7 +139,7 @@ index c9333689f..f1a20e7fe 100644
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
@@ -4390,6 +4433,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
props->memory_total = props->memory_free = 0;
...@@ -152,7 +152,7 @@ index c9333689f..f1a20e7fe 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
bool events = false;
@@ -4974,6 +5030,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
...@@ -167,7 +167,7 @@ index c9333689f..f1a20e7fe 100644
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4989,6 +5046,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id;
...@@ -183,7 +183,7 @@ index c9333689f..f1a20e7fe 100644
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index b987d7aeb..5ad5623ae 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -5,6 +5,8 @@
...@@ -204,7 +204,7 @@ index 1f06be80e..2f9ef2dc0 100644 ...@@ -204,7 +204,7 @@ index 1f06be80e..2f9ef2dc0 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c58..1c07e767a 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
...@@ -223,10 +223,10 @@ index e9201cdc6..44ae76d66 100644 ...@@ -223,10 +223,10 @@ index e9201cdc6..44ae76d66 100644
} }
#endif #endif
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index ba95b4acc..f6f8f7a10 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -546,6 +546,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
GGML_UNUSED(dev);
}
...@@ -234,7 +234,7 @@ index 05ff6a5a6..032dee76d 100644
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
@@ -554,6 +555,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
...@@ -243,18 +243,18 @@ index 05ff6a5a6..032dee76d 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index a36c6560c..a234eda2e 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -236,6 +236,7 @@ class vk_memory_logger;
#endif
class vk_perf_logger;
static void ggml_vk_destroy_buffer(vk_buffer& buf);
static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
+static std::string ggml_vk_get_device_id(int device);
static constexpr uint32_t mul_mat_vec_max_cols = 8;
static constexpr uint32_t p021_max_gqa_ratio = 8;
@@ -12353,6 +12354,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
snprintf(description, description_size, "%s", props.deviceName.data());
} }
...@@ -284,7 +284,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -284,7 +284,7 @@ index 3a6bbe564..ca02ea079 100644
// backend interface // backend interface
#define UNUSED GGML_UNUSED #define UNUSED GGML_UNUSED
@@ -12761,31 +12785,102 @@ void ggml_backend_vk_get_device_description(int device, char * description, size @@ -13614,15 +13638,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
ggml_vk_get_device_description(dev_idx, description, description_size); ggml_vk_get_device_description(dev_idx, description, description_size);
} }
...@@ -312,24 +312,23 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -312,24 +312,23 @@ index 3a6bbe564..ca02ea079 100644
+ int driver_major; + int driver_major;
+ int driver_minor; + int driver_minor;
+}; +};
+
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { +void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); + GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
+ +
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; vk::PhysicalDeviceMemoryProperties2 memprops = {};
- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; - const bool membudget_supported = vk_instance.device_supports_membudget[device];
- vk::PhysicalDeviceMemoryProperties2 memprops = {}; + const bool membudget_supported = vk_instance.device_supports_membudget[ctx->device];
- bool membudget_supported = vk_instance.device_supports_membudget[device]; const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); +
+ vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceProperties2 props2;
+ vkdev.getProperties2(&props2); + vkdev.getProperties2(&props2);
+
- if (membudget_supported) { + if (!is_integrated_gpu)
- memprops.pNext = &budgetprops;
+ if (!ctx->is_integrated_gpu)
+ { + {
+ // Use vendor specific management libraries for best VRAM reporting if available + // Use vendor specific management libraries for best VRAM reporting if available
+ switch (props2.properties.vendorID) { + switch (props2.properties.vendorID) {
...@@ -356,55 +355,13 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -356,55 +355,13 @@ index 3a6bbe564..ca02ea079 100644
+ } + }
+ break; + break;
+ } + }
}
- vkdev.getMemoryProperties2(&memprops);
+ // else fallback to memory budget if supported
- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
+ *total = 0;
+ *free = 0;
+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
+ vk::PhysicalDeviceMemoryProperties2 memprops2;
+ memprops2.pNext = &mem_budget_props;
+ vkdev.getMemoryProperties2(&memprops2);
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ } else if (ctx->is_integrated_gpu) {
+ // Include shared memory on iGPUs
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ }
+ }
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *free += mem_budget_props.heapBudget[i];
+ } else if (ctx->is_integrated_gpu) {
+ *free += mem_budget_props.heapBudget[i];
+ }
+ }
+ if (*total > 0 && *free > 0) {
+ return;
+ } else if (*total > 0) {
+ *free = *total;
+ return;
+ } + }
+ // else fallback to memory budget if supported
+
+ // else just report the physical memory if (membudget_supported) {
+ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { memprops.pNext = &budgetprops;
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { @@ -13674,8 +13755,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
*total = heap.size;
-
- if (membudget_supported && i < budgetprops.heapUsage.size()) {
- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
- } else {
- *free = heap.size;
- }
+ *free = heap.size;
break;
}
}
@@ -12818,8 +12913,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
} }
} }
...@@ -419,7 +376,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -419,7 +376,7 @@ index 3a6bbe564..ca02ea079 100644
} }
vk::PhysicalDeviceProperties2 props = {}; vk::PhysicalDeviceProperties2 props = {};
@@ -12836,19 +12936,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { @@ -13692,19 +13778,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
char pci_bus_id[16] = {}; char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
...@@ -453,7 +410,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -453,7 +410,7 @@ index 3a6bbe564..ca02ea079 100644
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -12860,9 +12965,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de @@ -13716,9 +13807,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
return ctx->description.c_str(); return ctx->description.c_str();
} }
...@@ -469,7 +426,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -469,7 +426,7 @@ index 3a6bbe564..ca02ea079 100644
} }
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -12886,8 +12996,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml @@ -13742,8 +13838,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev); props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev); props->description = ggml_backend_vk_device_get_description(dev);
...@@ -480,7 +437,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -480,7 +437,7 @@ index 3a6bbe564..ca02ea079 100644
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = { props->caps = {
/* .async = */ false, /* .async = */ false,
@@ -12895,6 +13006,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml @@ -13751,6 +13848,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
/* .buffer_from_host_ptr = */ false, /* .buffer_from_host_ptr = */ false,
/* .events = */ false, /* .events = */ false,
}; };
...@@ -494,7 +451,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -494,7 +451,7 @@ index 3a6bbe564..ca02ea079 100644
} }
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -13365,6 +13483,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -14319,6 +14423,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (!initialized) { if (!initialized) {
...@@ -503,7 +460,7 @@ index 3a6bbe564..ca02ea079 100644 ...@@ -503,7 +460,7 @@ index 3a6bbe564..ca02ea079 100644
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256]; char desc[256];
@@ -13373,12 +13493,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -14327,12 +14433,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->name = GGML_VK_NAME + std::to_string(i); ctx->name = GGML_VK_NAME + std::to_string(i);
ctx->description = desc; ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
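The body of ggml_vk_get_device_id added above is collapsed in this view. For orientation, a minimal sketch of the usual vulkan-hpp way to derive a stable identifier from the Vulkan 1.1 device-ID properties; this is an assumption about the approach, not the patch's exact code (requires <sstream> and <iomanip>):

static std::string device_uuid_hex(vk::PhysicalDevice dev) {
    vk::PhysicalDeviceIDProperties id_props;  // core since Vulkan 1.1, fills deviceUUID
    vk::PhysicalDeviceProperties2 props2;
    props2.pNext = &id_props;
    dev.getProperties2(&props2);
    std::ostringstream oss;
    oss << std::hex << std::setfill('0');
    for (uint8_t b : id_props.deviceUUID) {
        oss << std::setw(2) << (unsigned) b;  // 16 bytes -> 32 hex chars
    }
    return oss.str();
}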
......
...@@ -6,108 +6,101 @@ Subject: [PATCH] interleave multi rope ...@@ -6,108 +6,101 @@ Subject: [PATCH] interleave multi rope
Since ollama doesn't use mrope for anything else, change it to mean the Since ollama doesn't use mrope for anything else, change it to mean the
interleaved version used for qwen3vl. interleaved version used for qwen3vl.
--- ---
ggml/src/ggml-cpu/ops.cpp | 7 ++----- ggml/src/ggml-cpu/ops.cpp | 8 ++++----
ggml/src/ggml-cuda/rope.cu | 12 +++--------- ggml/src/ggml-cuda/rope.cu | 8 ++++----
ggml/src/ggml-metal/ggml-metal.metal | 10 +++------- ggml/src/ggml-metal/ggml-metal.metal | 8 ++++----
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp | 12 +++--------- ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 8 ++++----
4 files changed, 11 insertions(+), 30 deletions(-) 4 files changed, 16 insertions(+), 16 deletions(-)
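Before the per-backend hunks, the shape of the change: standard mrope assigns contiguous blocks of rotary dimensions to the temporal/height/width/extra position channels, while the interleaved (imrope) variant cycles through them every three dimensions. A minimal sketch of the sector-to-channel mapping, as a hypothetical helper mirroring the CPU hunk below:

static int imrope_pos_channel(int sector, const int * sections) {
    // sections[] = {temporal, height, width, extra} dimension counts;
    // sector = (i0 / 2) % (sections[0] + sections[1] + sections[2] + sections[3])
    if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) return 1;  // height position
    if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) return 2;  // width position
    return 0;  // temporal by default; the unused 'extra' branch is commented out below
}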
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 902fdad69..70955347d 100644 index 40666bab6..3155cb4bb 100644
--- a/ggml/src/ggml-cpu/ops.cpp --- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp
@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init( @@ -5599,14 +5599,14 @@ static void ggml_mrope_cache_init(
}
float theta = theta_t; float theta = theta_t;
- if (sector >= sections[0] && sector < sec_w) { if (is_imrope) { // qwen3vl apply interleaved mrope
+ if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) { - if (sector % 3 == 1 && sector < 3 * sections[1]) {
theta = theta_h; + if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) {
} theta = theta_h;
- else if (sector >= sec_w && sector < sec_w + sections[2]) { - } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+ else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) { + } else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) {
theta = theta_w; theta = theta_w;
} } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
- else if (sector >= sec_w + sections[2]) { theta = theta_t;
- theta = theta_e; - } else {
- } - theta = theta_e;
+ // } else {
rope_yarn( + // theta = theta_e;
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] }
} else {
if (sector >= sections[0] && sector < sec_w) {
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index d058504cd..287fe9d2c 100644 index 88ed79111..71ca60214 100644
--- a/ggml/src/ggml-cuda/rope.cu --- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu
@@ -151,19 +151,13 @@ static __global__ void rope_multi( @@ -200,14 +200,14 @@ static __global__ void rope_multi(
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
- float theta_base = 0.0;
- if (sector < sections.v[0]) {
- theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
- }
- else if (sector >= sections.v[0] && sector < sec_w) {
+ float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+ if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+ else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
- else if (sector >= sec_w + sections.v[2]) {
- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
- }
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float theta_base = 0.0;
if (is_imrope) {
- if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+ if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
- } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+ } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
- } else {
- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+ // } else {
+ // theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 50b8071de..65a3183c8 100644 index aed013a9d..a489de435 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal --- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3888,15 +3888,11 @@ kernel void kernel_rope_multi( @@ -4009,14 +4009,14 @@ kernel void kernel_rope_multi(
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
- float theta_base;
- if (sector < args.sect_0) {
- theta_base = (float) pos[i2];
- } else if (sector < sec_w01) {
+ float theta_base = (float) pos[i2];
+ if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) {
theta_base = (float) pos[i2 + args.ne02];
- } else if (sector < sec_w012) {
+ } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) {
theta_base = (float) pos[i2 + args.ne02 * 2];
- } else {
- theta_base = (float) pos[i2 + args.ne02 * 3];
}
// end of mrope
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp float theta_base;
index 111286b49..633dc20ff 100644 if (FC_rope_is_imrope) {
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp - if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp + if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { // h
@@ -31,19 +31,13 @@ void main() { theta_base = (float) pos[i2 + args.ne02 * 1];
const int sec_w = p.sections[1] + p.sections[0]; - } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
const uint sector = (i0 / 2) % sect_dims; + } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
- float theta_base = 0.0; } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
- if (sector < p.sections[0]) { theta_base = (float) pos[i2 + args.ne02 * 0];
- theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } else { // e
- } - theta_base = (float) pos[i2 + args.ne02 * 3];
- else if (sector >= p.sections[0] && sector < sec_w) { + // } else { // e
+ float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + // theta_base = (float) pos[i2 + args.ne02 * 3];
+ if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) { }
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); } else {
} if (sector < args.sect_0) {
- else if (sector >= sec_w && sector < sec_w + p.sections[2]) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
+ else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) { index 9726b722d..1c8c69422 100644
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
} +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
- else if (sector >= sec_w + p.sections[2]) { @@ -148,14 +148,14 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
- theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
- }
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float theta_base = 0.0;
if (p.is_imrope != 0) {
- if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
+ if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) {
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
- } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
+ } else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) {
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
- } else {
- theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ //} else {
+ // theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
...@@ -6,13 +6,13 @@ Subject: [PATCH] Add memory detection using DXGI + PDH ...@@ -6,13 +6,13 @@ Subject: [PATCH] Add memory detection using DXGI + PDH
--- ---
ggml/src/CMakeLists.txt | 1 + ggml/src/CMakeLists.txt | 1 +
ggml/src/ggml-impl.h | 3 + ggml/src/ggml-impl.h | 3 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 29 ++- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 26 ++-
ggml/src/mem_dxgi_pdh.cpp | 297 +++++++++++++++++++++++++++ ggml/src/mem_dxgi_pdh.cpp | 297 +++++++++++++++++++++++++++
4 files changed, 327 insertions(+), 3 deletions(-) 4 files changed, 325 insertions(+), 2 deletions(-)
create mode 100644 ggml/src/mem_dxgi_pdh.cpp create mode 100644 ggml/src/mem_dxgi_pdh.cpp
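mem_dxgi_pdh.cpp itself is collapsed below (297 lines). As a sketch of what the name implies: DXGI reports per-adapter totals but only per-process usage, so a PDH counter supplies adapter-wide usage. A rough outline against the public Windows APIs; the structure and helper name are assumptions, not the file's contents:

#include <dxgi1_4.h>
#include <pdh.h>  // PdhOpenQuery / PdhAddEnglishCounterA on "\\GPU Adapter Memory(*)\\Dedicated Usage"

// link against dxgi.lib and pdh.lib
static bool query_vram_by_luid(LUID luid, size_t * free_mem, size_t * total_mem) {
    IDXGIFactory4 * factory = nullptr;
    if (FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&factory)))) {
        return false;
    }
    IDXGIAdapter3 * adapter = nullptr;
    if (FAILED(factory->EnumAdapterByLuid(luid, IID_PPV_ARGS(&adapter)))) {
        factory->Release();
        return false;
    }
    DXGI_ADAPTER_DESC2 desc = {};
    adapter->GetDesc2(&desc);
    *total_mem = desc.DedicatedVideoMemory;
    DXGI_QUERY_VIDEO_MEMORY_INFO info = {};
    adapter->QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &info);
    // info.CurrentUsage covers this process only; subtracting the PDH
    // "Dedicated Usage" total instead gives a system-wide free figure.
    *free_mem = *total_mem > info.CurrentUsage ? *total_mem - info.CurrentUsage : 0;
    adapter->Release();
    factory->Release();
    return true;
}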
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 03f359ae9..4b3e5efb5 100644 index ac8f38464..faa1beed2 100644
--- a/ggml/src/CMakeLists.txt --- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt
@@ -211,6 +211,7 @@ add_library(ggml-base @@ -211,6 +211,7 @@ add_library(ggml-base
...@@ -22,9 +22,9 @@ index 03f359ae9..4b3e5efb5 100644 ...@@ -22,9 +22,9 @@ index 03f359ae9..4b3e5efb5 100644
+ mem_dxgi_pdh.cpp + mem_dxgi_pdh.cpp
gguf.cpp) gguf.cpp)
target_include_directories(ggml-base PRIVATE .) set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 44ae76d66..639d551a2 100644 index 1c07e767a..0da3e065b 100644
--- a/ggml/src/ggml-impl.h --- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release(); @@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
...@@ -38,10 +38,10 @@ index 44ae76d66..639d551a2 100644 ...@@ -38,10 +38,10 @@ index 44ae76d66..639d551a2 100644
#ifdef __cplusplus #ifdef __cplusplus
} }
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ca02ea079..c12b069e5 100644 index a234eda2e..c98f98c73 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" #define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) #define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
...@@ -49,7 +49,7 @@ index ca02ea079..c12b069e5 100644 ...@@ -49,7 +49,7 @@ index ca02ea079..c12b069e5 100644
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
VkStructureType sType; VkStructureType sType;
@@ -12802,6 +12803,7 @@ struct ggml_backend_vk_device_context { @@ -13655,6 +13656,7 @@ struct ggml_backend_vk_device_context {
std::string pci_id; std::string pci_id;
std::string id; std::string id;
std::string uuid; std::string uuid;
...@@ -57,8 +57,8 @@ index ca02ea079..c12b069e5 100644 ...@@ -57,8 +57,8 @@ index ca02ea079..c12b069e5 100644
int major; int major;
int minor; int minor;
int driver_major; int driver_major;
@@ -12817,8 +12819,22 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size @@ -13673,6 +13675,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
vk::PhysicalDeviceProperties2 props2; vk::PhysicalDeviceProperties2 props2;
vkdev.getProperties2(&props2); vkdev.getProperties2(&props2);
+ GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str()); + GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str());
...@@ -76,22 +76,17 @@ index ca02ea079..c12b069e5 100644 ...@@ -76,22 +76,17 @@ index ca02ea079..c12b069e5 100644
+ ggml_dxgi_pdh_release(); + ggml_dxgi_pdh_release();
+ } + }
- if (!ctx->is_integrated_gpu) if (!is_integrated_gpu)
+ if (!ctx->is_integrated_gpu)
{ {
// Use vendor specific management libraries for best VRAM reporting if available @@ -13704,7 +13720,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
switch (props2.properties.vendorID) {
@@ -12846,8 +12862,8 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
break;
}
} }
- // else fallback to memory budget if supported // else fallback to memory budget if supported
+ // else fallback to memory budget if supported -
*total = 0; if (membudget_supported) {
*free = 0; memprops.pNext = &budgetprops;
vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; }
@@ -13500,7 +13516,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -14440,7 +14455,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
/* .reg = */ reg, /* .reg = */ reg,
/* .context = */ ctx, /* .context = */ ctx,
}); });
...@@ -99,7 +94,7 @@ index ca02ea079..c12b069e5 100644 ...@@ -99,7 +94,7 @@ index ca02ea079..c12b069e5 100644
// Gather additional information about the device // Gather additional information about the device
int dev_idx = vk_instance.device_indices[i]; int dev_idx = vk_instance.device_indices[i];
vk::PhysicalDeviceProperties props1; vk::PhysicalDeviceProperties props1;
@@ -13523,6 +13538,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, @@ -14463,6 +14477,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
} }
} }
ctx->uuid = oss.str(); ctx->uuid = oss.str();
......
...@@ -10,10 +10,10 @@ fallback to cpu ...@@ -10,10 +10,10 @@ fallback to cpu
1 file changed, 3 insertions(+) 1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f1a20e7fe..1a71e07c9 100644 index 736d47c07..7350f6758 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g @@ -4564,6 +4564,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false; return false;
} }
......
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
(#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c12b069e5..76c78c2ea 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}
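In effect the multi-device copy drops one hop and one staging buffer. A condensed restatement of the new flow, using the helper signatures visible in the hunk above:

static void copy_across_devices(vk_buffer & dst, size_t dst_offset,
                                vk_buffer & src, size_t src_offset, size_t size) {
    ggml_vk_ensure_sync_staging_buffer(src->device, size);
    // device -> host-visible staging buffer on the source device
    ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
    // host pointer -> destination buffer; the write path can place the data
    // directly in host-visible vidmem, which the old staging-to-staging
    // memcpy route could not
    ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}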
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Wed, 29 Oct 2025 14:39:03 +0100
Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
* vulkan: add mmq q2_k integer dot support
* Refactor mmq caching
* Reduce mmq register use
* Load 4 quant blocks into shared memory in one step
* Pack q2_k blocks into caches of 32
* Use 32-bit accumulators for integer dot matmul
* Add q4_k mmq
* Add q3_k mmq
* Add q5_k mmq
* Add q6_k mmq
* Add mxfp4 mmq, enable MMQ MUL_MAT_ID
* Fix mmv dm loads
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 165 +++++-
.../vulkan-shaders/dequant_funcs.glsl | 10 +-
.../vulkan-shaders/dequant_funcs_cm2.glsl | 6 +-
.../vulkan-shaders/dequant_mxfp4.comp | 4 +-
.../vulkan-shaders/dequant_q2_k.comp | 4 +-
.../vulkan-shaders/dequant_q4_k.comp | 4 +-
.../vulkan-shaders/dequant_q5_k.comp | 4 +-
.../vulkan-shaders/mul_mat_vec_q2_k.comp | 6 +-
.../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 +-
.../vulkan-shaders/mul_mat_vec_q5_k.comp | 6 +-
.../ggml-vulkan/vulkan-shaders/mul_mm.comp | 72 +--
.../vulkan-shaders/mul_mm_funcs.glsl | 14 +-
.../vulkan-shaders/mul_mm_id_funcs.glsl | 70 +++
.../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 288 +++-------
.../vulkan-shaders/mul_mmq_funcs.glsl | 538 ++++++++++++++++--
.../vulkan-shaders/mul_mmq_shmem_types.glsl | 78 +++
.../src/ggml-vulkan/vulkan-shaders/types.glsl | 53 +-
.../vulkan-shaders/vulkan-shaders-gen.cpp | 5 +-
18 files changed, 928 insertions(+), 405 deletions(-)
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
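For orientation before the diff: the integer-dot ("MMQ") path quantizes the f32 activations to q8_1 on the GPU, multiplies weight blocks against activation blocks with integer dot products, and applies the per-block scales afterwards. A scalar model of one block product; an illustration, not shader code, and K-quants additionally subtract a dmin term using the q8_1 block sums:

static float block_dot_q8(const int8_t * wq, float wd,    // weight quants + scale
                          const int8_t * aq, float ad) {  // activation quants + scale
    int32_t acc = 0;  // 32-bit accumulator, per the commit message
    for (int k = 0; k < 32; ++k) {
        acc += (int32_t) wq[k] * (int32_t) aq[k];
    }
    return wd * ad * (float) acc;  // scales applied once per 32-value block
}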
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 76c78c2ea..7669ed206 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -488,6 +488,7 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
@@ -2449,8 +2450,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
+ l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
- l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
+ l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
+ l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
@@ -2513,10 +2517,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+ // Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+ // K-quants use even more registers, mitigate by setting WMITER to 1
+ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
+ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
+
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
@@ -2525,10 +2535,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+ s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+
+ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
- m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2913,18 +2931,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
@@ -2963,11 +2978,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
}
#endif
@@ -2997,6 +3020,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ }
+#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
@@ -3023,6 +3064,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
+#endif
}
#undef CREATE_MM2
#undef CREATE_MMQ
@@ -3087,6 +3146,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif
@@ -3146,7 +3211,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
&& !device->coopmat_bf16_support
#endif
) {
@@ -4930,7 +4995,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
- vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
@@ -5077,6 +5142,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
+ // MMQ
+ if (src1_type == GGML_TYPE_Q8_1) {
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
+
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ return nullptr;
+ }
+
+ return pipelines;
+ }
+
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
@@ -6879,10 +6955,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+ // Check for mmq first
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+ if (mmp == nullptr) {
+ // Fall back to f16 dequant mul mat
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ quantize_y = false;
+ }
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
- const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
@@ -6892,8 +6977,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
@@ -6906,12 +6991,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
+ vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
@@ -6926,9 +7012,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+ if (quantize_y) {
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+ }
+
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+ }
if (
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
@@ -6937,7 +7030,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
@@ -6949,6 +7042,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
+ if (quantize_y) {
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+ }
return;
}
@@ -6985,6 +7081,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+ } else if (quantize_y) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -7016,6 +7115,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
}
}
+ if (quantize_y) {
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
+ }
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@@ -7024,14 +7134,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
+ uint32_t y_sz_total = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+ }
+
// compute
ggml_vk_matmul_id(
ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
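A note on the repeated CEIL_DIV(..., 144) * 144 rounding in this hunk: a q8_1 block holds 32 int8 quants plus two fp16 values (36 bytes), and the Vulkan shaders load these blocks in groups of four, so the quantized-activation buffer is padded to whole 144-byte groups. That reading of the constant is an inference from the block layout, expressed as a helper:

static uint64_t pad_to_q8_1_group(uint64_t bytes) {
    const uint64_t group = 4 * (32 + 2 * 2);  // 4 blocks * 36 bytes = 144
    return (bytes + group - 1) / group * group;
}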
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
index 0d98f5a9d..09676a623 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
- const vec2 d = vec2(data_a[a_offset + ib].d);
+ const vec2 dm = vec2(data_a[a_offset + ib].dm);
- return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
index 67baedf7c..8ac6482dc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
- const f16vec2 d = bl.block.d;
+ const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
- float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+ float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
- float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
index ffba5a77d..3194ba291 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
@@ -26,7 +26,7 @@ void main() {
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
- data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+ data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+ data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
index 58dc2e5df..dc05a7834 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
index 8b7be557e..0f23dc0a3 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
@@ -20,8 +20,8 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
index 6bc04670f..970469a60 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
@@ -19,8 +19,8 @@ void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 03ed25d3b..14093c0de 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
- temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+ temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index 21d07d2e5..49d91ad59 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 9e46c89a1..0d61b4966 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index a20788c4b..d260969f0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[BN];
-uint _ne1;
-
-#ifdef MUL_MAT_ID_USE_SUBGROUPS
-shared uvec4 ballots_sh[NUM_WARPS];
-
-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
- uint nei0shift = findLSB(p.nei0);
-
- uint ids[16];
- uint iter = 0;
-
- for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
- // prefetch up to 16 elements
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_LocalInvocationIndex;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-
- ballots_sh[gl_SubgroupID] = ballot;
- barrier();
-
- uint subgroup_base = 0;
- uint total = 0;
- for (uint k = 0; k < gl_NumSubgroups; ++k) {
- if (k == gl_SubgroupID) {
- subgroup_base = total;
- }
- total += subgroupBallotBitCount(ballots_sh[k]);
- }
- barrier();
-
- uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
- row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
- }
- _ne1 += total;
- iter &= 15;
- if (_ne1 >= (ic + 1) * BN) {
- break;
- }
- }
- barrier();
-}
-#endif // MUL_MAT_ID_USE_SUBGROUPS
-#endif // MUL_MAT_ID
-
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
+#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
index 0ebfbd646..ee5ded2e8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+ const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
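+        // One 16-bit load replaces the former two byte loads; unpack8 splits the word back into the two quant bytes.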
+ const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
- const vec2 d = vec2(data_a[ib].d);
+ const vec2 dm = vec2(data_a[ib].dm);
- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
- const float d = e8m0_to_fp32(data_a[ib].e);
+ const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
new file mode 100644
index 000000000..1d0e84ac9
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -0,0 +1,70 @@
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[BN];
+uint _ne1;
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+shared uvec4 ballots_sh[NUM_WARPS];
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
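+    // Walk the nei0 x nei1 expert-id matrix with subgroup ballots and pack the (ii0, ii1) pairs matching expert_idx into this workgroup's BN-wide row_ids window.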
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+ uint nei0shift = findLSB(p.nei0);
+
+ uint ids[16];
+ uint iter = 0;
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+ ballots_sh[gl_SubgroupID] = ballot;
+ barrier();
+
+ uint subgroup_base = 0;
+ uint total = 0;
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier();
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
+ }
+ _ne1 += total;
+ iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
+ }
+ barrier();
+}
+#endif // MUL_MAT_ID_USE_SUBGROUPS
+#endif // MUL_MAT_ID
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index b5d761c0b..8b238ac4b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
-#endif
+#define MMQ_SHMEM
-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#include "mul_mmq_shmem_types.glsl"
-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
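+// BK_STEP is how many BK-wide k-slices are staged in shared memory per outer-loop iteration; quant types can override it (Q8_0 sets 1).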
+#ifndef BK_STEP
+#define BK_STEP 4
#endif
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
-#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
-#endif // MUL_MAT_ID
-
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
-
+#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
@@ -139,26 +128,12 @@ void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
-
-#ifdef COOPMAT
- const uint warp_i = gl_SubgroupID;
-
- const uint tiw = gl_SubgroupInvocationID;
-
- const uint cms_per_row = WM / TM;
- const uint cms_per_col = WN / TN;
-
- const uint storestride = WARP / TM;
- const uint store_r = tiw % TM;
- const uint store_c = tiw / TM;
-#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
-#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +147,27 @@ void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
- uint _ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
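+    // A power-of-two nei0 lets load_row_ids replace the division by p.nei0 with a shift.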
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true, ic);
+ } else {
+ load_row_ids(expert_idx, false, ic);
+ }
+#else
+ _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
_ne1++;
}
}
}
barrier();
+#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -209,159 +194,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
-#ifdef COOPMAT
- coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
- coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
- coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
-
- [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
- sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
- }
-#else
- int32_t cache_a_qs[WMITER * TM * BK / 4];
-
- int32_t cache_b_qs[TN * BK / 4];
-
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
-#endif
-#if QUANT_AUXF == 1
- FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
- FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
- FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
- for (uint block = start_k; block < end_k; block += BK) {
+ for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
- const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
- const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
+ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+ const uint iqs = loadr_a;
- if (iqs == 0) {
-#if QUANT_AUXF == 1
- buf_a_dm[buf_ib] = get_d(ib);
-#else
- buf_a_dm[buf_ib] = get_dm(ib);
-#endif
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
-#if QUANT_R == 1
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
- const i32vec2 vals = repack(ib, iqs);
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+ const uint buf_ib = loadc_b + l;
+
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
- const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
- const uint ib = idx / 8;
- const uint iqs = idx & 0x7;
+ const u16vec2 row_idx = row_ids[buf_ib];
+ const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
- const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
- const uint ib_outer = ib / 4;
- const uint ib_inner = ib % 4;
-
- const uint iqs = loadr_b;
+ const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
+ const uint iqs = loadr_b;
- const uint buf_ib = loadc_b + l;
-
- if (iqs == 0) {
- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
- const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
- pos_a_ib += 1;
- pos_b_ib += 1;
+ pos_a_ib += BK_STEP;
+ pos_b_ib += BK_STEP;
-#ifdef COOPMAT
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- const uint ib_a = warp_r * WM + cm_row * TM;
+ for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
- coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
- cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
- }
-
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const uint ib_b = warp_c * WN + cm_col * TN;
- coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
- cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
- }
-
- cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
- cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
- }
-
- coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
- sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
- }
- }
-#else
- // Load from shared into cache
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
- }
- }
- }
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint reg_ib = wsir * TM + cr;
+ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
- cache_b_ds[cc] = buf_b_ds[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+ block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint cache_a_idx = wsir * TM + cr;
- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
- int32_t q_sum = 0;
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
- cache_b_qs[cc * (BK / 4) + idx_k]);
- }
+ const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+ block_b_to_registers(ib);
- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint cache_a_idx = wsir * TM + cr;
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+
+ sums[sums_idx] += mmq_dot_product(cache_a_idx);
+ }
}
}
}
}
-#endif
barrier();
}
@@ -373,54 +269,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < BN; col += storestride) {
- const uint row_i = dc + cm_col * TN + col + store_c;
- if (row_i >= _ne1) break;
-
- const u16vec2 row_idx = row_ids[row_i];
-
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
-#else
- const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
-
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
- if (is_aligned && is_in_bounds) {
- // Full coopMat is within bounds and stride_d is aligned with 16B
- coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
- coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
- } else if (is_in_bounds) {
- // Full coopMat is within bounds, but stride_d is not aligned
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
- // Partial coopMat is within bounds
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
- }
- }
-#endif // MUL_MAT_ID
-#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -431,19 +279,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ if (dr_warp + cr < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+ }
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
-#endif // COOPMAT
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index fe71eb131..c0c03fedc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+#ifdef DATA_A_Q4_0
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
+ return i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+#endif
}
+#ifdef DATA_A_Q4_0
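+// The 8 * dsb.y term cancels the implicit -8 offset of Q4_0 quants; dsb.y carries the q8_1 block sum pre-multiplied by its scale.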
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q4_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- return i32vec2( vui & 0x0F0F0F0F,
- (vui >> 4) & 0x0F0F0F0F);
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+#else // DATA_A_Q4_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
-#endif
-#if defined(DATA_A_Q5_0)
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
- const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+ const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}
+#ifdef DATA_A_Q5_0
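+// Same correction as Q4_0, but Q5_0 quants carry an implicit -16 offset.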
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q5_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
- const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
- | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
- const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
- | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+ }
+#else // DATA_A_Q5_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
- return i32vec2(v0, v1);
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+ const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+ const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
- return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]));
+ return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+ const int32_t qs_b = cache_b.qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, qs_b);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+ const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ return i32vec2( quants & 0x0F0F0F0F,
+ (quants >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(da * dsb.x * float(q_sum));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
+ const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+ buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
+ buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d = buf_a[buf_ib].d;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
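+// e.g. ib = 10 addresses 256-wide superblock ib / 8 = 1 and 32-wide sub-block ib % 8 = 2 within it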
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+int32_t repack(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ return data_a[ib_k].scales[iqs_k / 4];
+}
+
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ // Repack 4x4 quants into one int
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+ const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+ const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
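+    // Each byte of the packed word now holds four 2-bit quants (vals0 in bits 0-1 ... vals3 in bits 6-7); mmq_dot_product peels them off two bits at a time.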
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+ buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t sum_d = 0;
+ int32_t sum_m = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+ const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+ sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+ sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint hm_idx = iqs * QUANT_R_MMQ;
+ const uint iqs_k = (ib % 8) * 8 + hm_idx;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+ const uint hm_shift = iqs_k / 8;
+
+ // Repack 2x4 quants into one int
+ // Add the 3rd bit instead of subtracting it to allow packing the quants
+ const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+ buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
+ (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
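+    // Two nibble-wide quant planes share each byte; mmq_dot_product unpacks them with a 4-bit shift and subtracts 4 to undo the added third bit.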
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
+ (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ // Subtract 4 from the quants to correct the 3rd bit offset
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+ // Repack 2x4 quants into one int
+#if defined(DATA_A_Q4_K)
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
+#else // defined(DATA_A_Q5_K)
+ const uint qh_idx = iqs * QUANT_R_MMQ;
+ const uint qh_shift = iqs_k / 8;
+
+ buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
+ (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
+#endif
+
+
+ if (iqs == 0) {
+ // Scale index
+ const uint is = iqs_k / 8;
+ u8vec2 scale_dm;
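+        // 6-bit K-quant scale/min layout: entries 0-3 use the low six bits directly; entries 4-7 combine a nibble of scales[is+4] with the top two bits of scales[is-4] and scales[is].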
+ if (is < 4) {
+ scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+ } else {
+ scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+ (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
+ }
+
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+#if defined(DATA_A_Q4_K)
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
+#else // defined(DATA_A_Q5_K)
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+#endif
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#ifdef MMQ_SHMEM
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_outer = ib / 4;
+ const uint ib_inner = ib % 4;
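+    // data_b packs four 32-wide q8_1 sub-blocks per struct, hence the outer/inner index split.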
+
+ if (iqs == 0) {
+ buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ }
+
+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+ buf_b[buf_ib].qs[iqs * 4 ] = values.x;
+ buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+ buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+ buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+ cache_b.ds = buf_b[ib].ds;
+ [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+ }
+}
+#endif
+
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+ const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+ const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+ const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+ const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+ const uint ib_k = ib / 8;
+ return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
new file mode 100644
index 000000000..72fec4404
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -0,0 +1,78 @@
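+// Per-quant-type shared-memory and register cache layouts for the integer-dot (mmq) matmul path.
+// QUANT_R_MMQ feeds LOAD_VEC_A (= 4 * QUANT_R_MMQ) in mul_mmq.comp.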
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+// AMD likes 4, Intel likes 1 and Nvidia likes 2
+#define BK_STEP 1
+struct block_a_cache {
+ int32_t qs[32/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_MXFP4)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE d;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache {
+ uint32_t qs[2];
+ u8vec2 scales;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q3_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#elif defined(DATA_A_Q4_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q6_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#endif
+
+struct block_b_cache
+{
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 ds;
+};
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
index 2fa54ce51..02578c77c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
- f16vec2 d;
+ f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
+#define SCALES_PER_32 2
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
- f16vec2 d;
+ f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@ struct block_q5_K
struct block_q5_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};
+struct block_q5_K_packed32
+{
+ f16vec2 dm;
+ uint32_t scales[12/4];
+ uint32_t qh[QUANT_K_Q5_K/8/4];
+ uint32_t qs[QUANT_K_Q5_K/2/4];
+};
+
struct block_q5_K_packed128
{
uvec4 q5k[11];
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
+#define A_TYPE_PACKED32 block_q5_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
- int8_t scales[QUANT_K_Q6_K/16];
+ int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
+#define DATA_A_QUANT_K
#endif
// IQuants
@@ -1363,18 +1383,11 @@ struct block_mxfp4
uint8_t qs[QUANT_K_MXFP4/2];
};
-//struct block_mxfp4_packed16
-//{
-// uint8_t e;
-// uint16_t qs[QUANT_K_MXFP4/2/2];
-//};
-
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
-//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
#endif
#if defined(DATA_A_MXFP4)
-const FLOAT_TYPE kvalues_mxfp4_const[16] = {
- FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
- FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+const int8_t kvalues_mxfp4_const[16] = {
+ int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
+ int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};
-shared FLOAT_TYPE kvalues_mxfp4[16];
+shared int8_t kvalues_mxfp4[16];
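+// The int8 table stores the MXFP4 values doubled (half-steps become integers); dequant paths scale d by 0.5 to compensate.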
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 0f25ba345..03fa01639 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
- if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
+ // Integer dot mmq performs better with f32 accumulators
+ if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
@@ -8,7 +8,7 @@ Subject: [PATCH] win: exit instead of abort
 1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
-index 9be35c1be..923c33d05 100644
+index b99345a2e..1c9e0bc05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
...
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Tue, 11 Nov 2025 11:39:43 -0800
Subject: [PATCH] fix bakllava regression
Revert to prior logic of assuming an empty projector type is mlp
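A minimal sketch of the fallback (gguf_get_str_or_default is a
hypothetical helper, not the exact clip.cpp API):

    // bakllava-era GGUFs predate clip.vision.projector_type, so the
    // loader sees an empty string; treat that as the original MLP
    // projector instead of failing later on an unknown type.
    std::string proj_type =
        gguf_get_str_or_default(meta, "clip.vision.projector_type", ""); // hypothetical helper
    if (proj_type.empty()) {
        proj_type = "mlp";
    }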
---
tools/mtmd/clip.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index f4c4d2c48..3334ff25b 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2648,6 +2648,10 @@ struct clip_model_loader {
if (proj_type.empty()) {
if (modality == CLIP_MODALITY_VISION) {
get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
+ if (proj_type.empty()) {
+ // Assume MLP if no projector type listed
+ proj_type = "mlp";
+ }
} else if (modality == CLIP_MODALITY_AUDIO) {
get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
} else {
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
(#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax
Based on #16649.
* Add ggml_check_edges
* Add sync logging to show fusion effects
* handle clamp added in #16655
* Update ggml/src/ggml-impl.h
Co-authored-by: Diego Devesa <slarengh@gmail.com>
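A minimal standalone sketch of the edge-list idea used by the new
ggml_check_edges (simplified node type, not the real ggml_tensor /
ggml_cgraph structs):

    #include <array>
    #include <initializer_list>

    struct node { const node * src[4]; };

    // Each edge is {dst_node, src_idx, src_node}, indices relative to
    // the start of the candidate fusion window; the subgraph matches
    // only if every listed input points at the expected earlier node.
    static bool edges_match(const node * nodes, int start,
                            std::initializer_list<std::array<int, 3>> edges) {
        for (const auto & e : edges) {
            if (nodes[start + e[0]].src[e[1]] != &nodes[start + e[2]]) {
                return false;
            }
        }
        return true;
    }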
---
ggml/src/ggml-impl.h | 16 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++-------
.../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++--
3 files changed, 272 insertions(+), 138 deletions(-)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
+#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 7669ed206..63a762ec2 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+ GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+ { 5, 0, 4 }, // reshape->src[0] == get_rows
+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
+ { 7, 0, 6 }, // clamp->src[0] == sum_rows
+ { 8, 0, 5 }, // div->src[0] == reshape
+ { 8, 1, 7 }, // div->src[1] == clamp
+ { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+};
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+ { 1, 0, 0 }, // view->src[0] == argsort
+ { 2, 1, 1 }, // get_rows->src[1] == view
+ { 3, 0, 2 }, // reshape->src[0] == get_rows
+ { 4, 0, 3 }, // soft_max->src[0] == reshape
+ { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+ TOPK_MOE_EARLY_SOFTMAX,
+ TOPK_MOE_EARLY_SOFTMAX_NORM,
+ TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
+ TOPK_MOE_LATE_SOFTMAX;
+ return mode;
+}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -607,8 +671,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
- // [2] is {!norm, norm}
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
- return ctx->device->pipeline_topk_moe[idx][with_norm];
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
+ if (ctx->num_additional_fused_ops) {
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
+ }
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
- ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
+ cgraph->nodes[node_idx + 5];
+ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
- vk_op_topk_moe_push_constants pc;
+ vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
+ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+
+#define ENABLE_SYNC_LOGGING 0
+
if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+ std::cerr << "sync" << std::endl;
+#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+#if ENABLE_SYNC_LOGGING
+ if (!dryrun) {
+ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ auto *n = cgraph->nodes[node_idx + i];
+ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+ if (n->op == GGML_OP_GLU) {
+ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+ }
+ std::cerr << std::endl;
+ }
+ }
+#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ if (ctx->num_additional_fused_ops) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+ } else {
+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ }
break;
case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
- int node_idx, bool with_norm) {
+ int node_idx, topk_moe_mode mode) {
- if (with_norm) {
- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
- return false;
- }
- }
- } else {
- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
- return false;
- }
- }
- }
+ const ggml_tensor * softmax;
+ const ggml_tensor * weights;
- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
- const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+ switch (mode) {
+ case TOPK_MOE_EARLY_SOFTMAX_NORM:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 9];
+ break;
+ case TOPK_MOE_EARLY_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 4];
+ break;
+ case TOPK_MOE_LATE_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 4];
+ weights = cgraph->nodes[node_idx + 5];
+ break;
+ default:
+ return false;
+ }
const float * op_params = (const float *)softmax->op_params;
@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
- // Check that the nodes don't have any unexpected uses
- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
- // softmax is used by reshape and argsort
- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
- reshape1->src[0] != softmax ||
- argsort->src[0] != softmax) {
- return false;
- }
- // reshape is used by get_rows
- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
- get_rows->src[0] != reshape1) {
- return false;
- }
- // argsort is used by view
- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
- view->src[0] != argsort) {
- return false;
- }
- // view is written (via argsort), we can skip checking it
-
- if (with_norm) {
- // get_rows is used by reshape
- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
- reshape5->src[0] != get_rows) {
- return false;
- }
-
- // reshape is used by sum_rows and div
- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
- sum_rows->src[0] != reshape5 ||
- div->src[0] != reshape5) {
- return false;
- }
-
- // sum_rows is used by div
- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
- div->src[1] != sum_rows) {
- return false;
- }
-
- // div/reshape are written
- if (reshape8->src[0] != div) {
- return false;
- }
- }
-
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
- // Avoid reordering topk_moe_norm
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
- bool is_topk_moe_norm = true;
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
- is_topk_moe_norm = false;
+ // Check for fusion patterns and avoid reordering them
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+ if (start + (int)pattern.size() <= graph->n_nodes) {
+ bool is_pattern = true;
+ for (size_t j = 0; j < pattern.size(); ++j) {
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+ is_pattern = false;
+ }
}
+ return is_pattern;
}
- if (is_topk_moe_norm) {
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+ return false;
+ };
+
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+ if (match_pattern(pattern, first_unused)) {
+ for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
- continue;
+ return true;
}
+ return false;
+ };
+
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
+ continue;
+ }
+ if (keep_pattern(topk_moe_early_softmax)) {
+ continue;
}
+ if (keep_pattern(topk_moe_late_softmax)) {
+ continue;
+ }
+
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
+ // Don't pull forward nodes from fusion patterns
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_early_softmax, j) ||
+ match_pattern(topk_moe_late_softmax, j)) {
+ continue;
+ }
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
-void main() {
- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
- if (row >= n_rows) {
- return;
- }
+const float INFINITY = 1.0 / 0.0;
- const uint logits_offset = n_experts * row;
- const uint weights_offset = n_expert_used * row;
- const uint ids_offset = n_experts * row;
-
- float logits_r[experts_per_thread];
-
- const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
[[unroll]]
- for (uint i = 0; i < n_experts; i += WARP_SIZE) {
- const uint expert = i + gl_LocalInvocationID.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
}
- float max_val = logits_r[0];
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
[[unroll]]
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
}
- max_val = subgroupMax(max_val);
+ sum = subgroupAdd(sum);
- float wt[experts_per_thread];
- float tmp = 0.f;
+ const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = exp(val - max_val);
- tmp += wt[i];
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
}
+}
- tmp = subgroupAdd(tmp);
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+ if (row >= n_rows) {
+ return;
+ }
- const float inv_sum = 1.0f / tmp;
+ const uint logits_offset = n_experts * row;
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+
+ float wt[experts_per_thread];
[[unroll]]
- for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + gl_LocalInvocationID.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (!late_softmax) {
+ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
+ if (late_softmax) {
+ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+ }
+
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 15:13:10 -0500
Subject: [PATCH] vulkan: Fuse rope+set_rows (#16769)
This pattern appears in a lot of models: the rope operation is applied right
before storing into the KV cache (usually on the K tensor).
Add a path to some of the rope shaders that computes the destination address
based on the set_rows tensor. Compile variants of the shader with D_TYPE of
f16 (the usual KV cache type).
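In scalar form the fused store amounts to something like this sketch
(store_roped, row_ids, and the float destination are illustrative, and
set_rows_stride == 0 meaning "fusion inactive" is an assumption; the
real shader writes f16 into the KV cache):

    #include <cstdint>

    // Instead of storing the roped row at its own row index, look up
    // the destination row in the KV cache from the set_rows indices.
    static inline void store_roped(float * dst, const int64_t * row_ids,
                                   uint32_t set_rows_stride, // dst row stride; assumed 0 when unfused
                                   uint32_t i0, uint32_t i1, float value) {
        const int64_t dst_row = row_ids[i1]; // src3: set_rows ids
        dst[dst_row * set_rows_stride + i0] = value;
    }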
Add a src3 operand to ggml_vk_op_f32 - sometimes rope uses three srcs and needs
the fourth for the row indices.
Add fused_ops_write_mask to indicate which intermediate tensors need to write
their results to memory. Skipping writing the roped K value helps to allow more
nodes to run concurrently.
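For the rope fusion that ends up as, roughly (the concrete mask value
here is an assumption for illustration):

    // Bit i covers nodes[start_of_fusion + i]. For ROPE (0) + VIEW (1)
    // + SET_ROWS (2), only the SET_ROWS result must reach memory, so
    // the intermediate roped K tensor is never written out:
    ctx->fused_ops_write_mask = 1 << 2;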
Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive. It
rarely starts out that way in the graph.
Add new backend tests.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 334 +++++++++++++-----
.../ggml-vulkan/vulkan-shaders/rope_head.glsl | 2 +
.../ggml-vulkan/vulkan-shaders/rope_neox.comp | 13 +-
.../ggml-vulkan/vulkan-shaders/rope_norm.comp | 13 +-
.../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +
tests/test-backend-ops.cpp | 122 +++++--
6 files changed, 371 insertions(+), 117 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 63a762ec2..db92a7901 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -458,6 +458,11 @@ static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
return mode;
}
+static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
+ { 1, 0, 0 }, // view->src[0] == rope
+ { 2, 0, 1 }, // set_rows->src[0] == view
+};
+
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -640,8 +645,8 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_back_f32;
- vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
@@ -1054,6 +1059,7 @@ struct vk_op_rope_push_constants {
uint32_t s2;
int32_t sections[4];
uint32_t is_back;
+ uint32_t set_rows_stride;
};
struct vk_op_soft_max_push_constants {
@@ -1563,6 +1569,10 @@ struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the
// node currently being processed
int num_additional_fused_ops {};
+ // Bitmask of which fused ops need to write an intermediate value to memory.
+ // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
+ // If there's no fusion, bit 0 is still set.
+ int fused_ops_write_mask {};
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -3697,21 +3707,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -8170,7 +8186,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
- const int mode = ((const int32_t *) dst->op_params)[2];
+ const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
+ const int mode = ((const int32_t *) rope->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
@@ -8179,6 +8196,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_neox_f32;
}
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_neox_f32_f16;
+ }
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
}
@@ -8200,6 +8220,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_norm_f32;
}
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_norm_f32_f16;
+ }
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_norm_f16;
}
@@ -8409,20 +8432,22 @@ static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_ten
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
}
-template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
GGML_UNUSED(p);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
GGML_UNUSED(dst);
static_assert(!std::is_const<T>::value, "unexpected type");
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
+ GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8430,9 +8455,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8440,9 +8466,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8450,9 +8477,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8460,9 +8488,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src0);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8472,9 +8501,10 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
-template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8483,10 +8513,11 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src1);
GGML_UNUSED(src2);
+ GGML_UNUSED(src3);
}
template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
if (src1 != nullptr) {
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -8494,6 +8525,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (src2 != nullptr) {
std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
}
+ if (src3 != nullptr) {
+ std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
+ }
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
@@ -8520,6 +8554,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
const uint64_t ne2 = ne20 * ne21;
+ const bool use_src3 = src3 != nullptr;
+ const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
+ const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
+ const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
+ const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
+ const uint64_t ne3 = ne30 * ne31;
+
const uint64_t ned0 = dst->ne[0];
const uint64_t ned1 = dst->ne[1];
const uint64_t ned2 = dst->ne[2];
@@ -8550,6 +8591,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
+ ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
vk_buffer d_X = nullptr;
size_t x_buf_offset = 0;
@@ -8557,10 +8599,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
size_t y_buf_offset = 0;
vk_buffer d_Z = nullptr;
size_t z_buf_offset = 0;
+ vk_buffer d_W = nullptr;
+ size_t w_buf_offset = 0;
bool src0_uma = false;
bool src1_uma = false;
bool src2_uma = false;
+ bool src3_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
@@ -8573,6 +8618,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
src2_uma = d_Z != nullptr;
}
+ if (use_src3) {
+ ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
+ src3_uma = d_W != nullptr;
+ }
}
vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -8594,11 +8643,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
}
+ if (use_src3 && !src3_uma) {
+ d_W = src3_buf_ctx->dev_buffer;
+ w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
+ GGML_ASSERT(d_W != nullptr);
+ }
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
- init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
+ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
+ w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
std::array<uint32_t, 3> elements;
@@ -8799,12 +8854,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
}
- uint64_t x_sz, y_sz, z_sz, d_sz;
+ uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+ w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
@@ -8816,6 +8872,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
+ if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
+ w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
+ }
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
@@ -8823,6 +8882,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
+ w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}
@@ -8864,14 +8924,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
- vk_subbuffer subbuf_z;
+ vk_subbuffer subbuf_z, subbuf_w;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}
+ if (use_src3) {
+ subbuf_w = { d_W, w_buf_offset, w_sz };
+ } else {
+ subbuf_w = { d_X, 0, x_sz };
+ }
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
@@ -8887,6 +8952,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
+ } else if (use_src3) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src2) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src1) {
@@ -8901,7 +8968,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -8921,7 +8988,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9046,7 +9113,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9061,7 +9128,7 @@ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9076,7 +9143,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9091,7 +9158,7 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9106,7 +9173,7 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t src2_type_size = ggml_type_size(src2->type);
- ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
(uint32_t)dst->ne[0],
(uint32_t)dst->ne[1],
(uint32_t)src0->nb[1] / src0_type_size,
@@ -9339,7 +9406,7 @@ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx,
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
+ ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@@ -9457,7 +9524,7 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9467,7 +9534,7 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9491,7 +9558,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
}
- ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
+ ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
@@ -9505,23 +9572,23 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
}
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9529,12 +9596,12 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
}
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
}
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9549,17 +9616,17 @@ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, cons
memcpy(&p.param1, &s01_packed, sizeof(float));
memcpy(&p.param2, &s23_packed, sizeof(float));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
}
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9575,7 +9642,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
}
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
}
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9590,7 +9657,7 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
return;
}
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9601,13 +9668,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9618,7 +9685,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
const float eps = float_op_params[1];
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -9641,7 +9708,7 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -9658,16 +9725,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9690,7 +9757,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
{
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0],
@@ -9703,7 +9770,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
int32_t * op_params = (int32_t *)dst->op_params;
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -9728,7 +9795,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -9744,7 +9811,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
@@ -9835,7 +9902,12 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
}, pc, elements);
}
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) {
+ ggml_tensor * dst = cgraph->nodes[node_idx];
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+ const ggml_tensor * src3 = nullptr;
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -9859,11 +9931,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
+ uint32_t set_rows_stride = 0;
+    // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride,
+    // overrides dst with the set_rows output, and sets src3 to the row indices
+ if (ctx->num_additional_fused_ops > 0) {
+ set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
+ src3 = cgraph->nodes[node_idx + 2]->src[1];
+ dst = cgraph->nodes[node_idx + 2];
+ }
+
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
- { sections[0], sections[1], sections[2], sections[3] }, backprop
+ { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
}, dryrun);
}
@@ -9872,7 +9953,7 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
uint32_t ncols = src0->ne[0];
- ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
op_params[0],
}, dryrun);
@@ -9880,26 +9961,26 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
}
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
p.weight = 1.0f / (float)src0->ne[0];
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9932,7 +10013,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
- ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
+ ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
@@ -10005,7 +10086,7 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
- ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
+ ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -10013,7 +10094,7 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
const uint32_t max_period = dst->op_params[1];
const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
+ ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
nb1, dim, max_period,
}, dryrun);
}
@@ -10046,7 +10127,7 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.s0 = static_cast<uint32_t>(s0);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
}
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -10069,7 +10150,7 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
const uint32_t parallel_elements = N * OC * OH * OW;
- ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+ ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
IW, IH, OW, OH, OC,
parallel_elements,
op,
@@ -10123,7 +10204,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
@@ -10172,7 +10253,7 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context
GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne12);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -10196,12 +10277,12 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(src0->ne[3] == p.channels);
GGML_ASSERT(src1->ne[3] == p.batches);
- ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
}
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -11327,7 +11408,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
@@ -11401,9 +11481,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// nodes require synchronization.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
- if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
- need_sync = true;
- break;
+ // If the node actually writes to memory, then check if it needs to sync
+ if (ctx->fused_ops_write_mask & (1 << i)) {
+ if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
+ need_sync = true;
+ break;
+ }
}
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
@@ -11430,7 +11513,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
- ctx->unsynced_nodes_written.push_back(cur_node);
+ if (ctx->fused_ops_write_mask & (1 << i)) {
+ ctx->unsynced_nodes_written.push_back(cur_node);
+ }
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
@@ -11621,11 +11706,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false, dryrun);
break;
case GGML_OP_ROPE_BACK:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true, dryrun);
break;
case GGML_OP_ARGSORT:
@@ -12487,6 +12572,41 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return true;
}
+static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+ int node_idx) {
+ GGML_UNUSED(ctx);
+ const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
+ const ggml_tensor *view = cgraph->nodes[node_idx + 1];
+ const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
+
+ // ne3 not tested
+ if (rope->src[0]->ne[3] != 1) {
+ return false;
+ }
+
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ // The view should flatten two dims of rope into one dim
+ if (!ggml_is_contiguous(view) ||
+ view->ne[0] != rope->ne[0] * rope->ne[1]) {
+ return false;
+ }
+
+ // Only norm/neox shaders have the fusion code
+ const int mode = ((const int32_t *) rope->op_params)[2];
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+ return false;
+ }
+
+ return true;
+}
+
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12562,6 +12682,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+ ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -12671,20 +12795,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
+ ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
+ ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 3;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 3;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 1;
}
}
+ ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
@@ -12730,6 +12865,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
+ ctx->fused_ops_write_mask = 0;
}
if (vk_perf_logger_enabled) {
@@ -12887,6 +13023,32 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
if (ok) {
current_set.push_back(j);
+ // Look for ROPE + VIEW + SET_ROWS and make them consecutive
+ if (graph->nodes[j]->op == GGML_OP_ROPE) {
+ int view_idx = -1;
+ int set_rows_idx = -1;
+ for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
+ if (view_idx == -1 &&
+ graph->nodes[k]->op == GGML_OP_VIEW &&
+ graph->nodes[k]->src[0] == graph->nodes[j]) {
+ view_idx = k;
+ continue;
+ }
+ if (view_idx != -1 &&
+ set_rows_idx == -1 &&
+ graph->nodes[k]->op == GGML_OP_SET_ROWS &&
+ graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
+ set_rows_idx = k;
+ break;
+ }
+ }
+ if (set_rows_idx != -1) {
+ current_set.push_back(view_idx);
+ current_set.push_back(set_rows_idx);
+ used[view_idx] = true;
+ used[set_rows_idx] = true;
+ }
+ }
}
}
// Second pass grabs view nodes.
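
A note on the write-mask bookkeeping above: for a fused run, bit i of fused_ops_write_mask marks node i as one that actually writes memory, and only those nodes take part in the synchronization overlap checks and get recorded as unsynced writers. The final node of a fused run always writes, which is what the unconditional "fused_ops_write_mask |= 1 << num_additional_fused_ops" covers; the topk_moe patterns additionally mark the view of argsort. A minimal sketch of how the masks come out (run lengths illustrative, not taken from the source):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // ROPE + VIEW + SET_ROWS: three nodes, only the last one writes.
        int rope_extra = 2;                    // num_additional_fused_ops
        uint32_t rope_mask = 1u << rope_extra;             // 0b100

        // topk_moe_early_softmax: the view of argsort (node 3) also writes.
        int moe_extra = 4;                     // illustrative run length minus one
        uint32_t moe_mask = (1u << 3) | (1u << moe_extra); // 0b11000

        printf("rope mask=0x%x, moe mask=0x%x\n", rope_mask, moe_mask);
        return 0;
    }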
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
index 50fc1f1e2..0eda186c8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
layout (push_constant) uniform parameter {
uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
uint s2;
int sections[4];
uint is_back;
+ uint set_rows_stride;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
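
In the fused path, binding 4 carries the SET_ROWS row indices, and a nonzero set_rows_stride in the push constants is the signal that the path is active. The indices are I64 on the host, but the shader declares the buffer as uvec2, so each 64-bit index arrives as a pair of 32-bit words and the shaders read only the low word, data_i[...].x. A host-side sketch of that split, assuming a little-endian layout:

    #include <cstdint>
    #include <cstdio>

    int main() {
        int64_t row_index = 1234;  // one I64 entry of the row-index tensor
        uint32_t x = (uint32_t)(row_index & 0xffffffff);     // data_i[channel_x].x
        uint32_t y = (uint32_t)((uint64_t)row_index >> 32);  // .y, unused by the shader
        printf("x=%u y=%u\n", x, y);
        return 0;
    }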
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 06e095bef..9f4538155 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
- const uint idst = row_dst*ne0 + i0/2;
+ uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
+ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+ if (p.set_rows_stride != 0) {
+ idst = row_x*ne0 + i0/2;
+ idst += data_i[channel_x].x * p.set_rows_stride;
+ }
+
if (i0 >= p.n_dims) {
- data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
- data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+ data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
+ data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
return;
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 6ba957540..f4209ed95 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
- const uint idst = row_dst*ne0 + i0;
+ uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
+ // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+ if (p.set_rows_stride != 0) {
+ idst = row_x*ne0 + i0;
+ idst += data_i[channel_x].x * p.set_rows_stride;
+ }
+
if (i0 >= p.n_dims) {
- data_d[idst + 0] = data_a[ix + 0];
- data_d[idst + 1] = data_a[ix + 1];
+ data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
+ data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
return;
}
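
Both rope shaders compute the fused destination the same way: channel_x (the token) no longer addresses the rope output directly; instead, the row fetched from data_i selects the destination row and row_x*ne0 addresses within it. A standalone sketch of the rope_norm index math, with illustrative sizes:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ne0 = 128, ne1 = 32;    // head dim, heads per token
        const uint32_t row_dst = 70, i0 = 4;   // flattened row, column index
        const uint32_t row_x     = row_dst % ne1;  // head within the token
        const uint32_t channel_x = row_dst / ne1;  // token index

        uint32_t idst = row_dst * ne0 + i0;    // unfused: into the rope output

        const uint32_t set_rows_stride = ne0 * ne1;  // dst row stride in elements
        const uint32_t dst_row = 5;            // data_i[channel_x].x (illustrative)
        idst = row_x * ne0 + i0 + dst_row * set_rows_stride;  // fused destination

        printf("channel=%u fused idst=%u\n", channel_x, idst);
        return 0;
    }

The explicit D_TYPE casts in the early-out branch matter for the same reason the new shader variants below exist: with fusion the destination tensor can be F16 while the rope input is F32.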
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 03fa01639..e6ec589fb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -842,10 +842,14 @@ void process_shaders() {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+ string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+ string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9eb2b6687..657b6cc2f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2105,6 +2105,34 @@ struct test_get_rows_back : public test_case {
}
};
+static void init_set_rows_row_ids(ggml_tensor * t, int num_rows) {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (int i2 = 0; i2 < t->ne[2]; i2++) {
+ for (int i1 = 0; i1 < t->ne[1]; i1++) {
+ // generate a shuffled subset of row indices
+ std::vector<int64_t> data(num_rows);
+ for (int i = 0; i < num_rows; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ data.resize(t->ne[0]);
+
+ const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+ if (t->type == GGML_TYPE_I32) {
+ // TODO: Make a template or something
+ std::vector<int32_t> data_i32(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data_i32[i] = static_cast<int32_t>(data[i]);
+ }
+ ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
+ } else {
+ ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+ }
+ }
+ }
+}
+
// GGML_OP_SET_ROWS
struct test_set_rows : public test_case {
const ggml_type type;
@@ -2148,37 +2176,13 @@ struct test_set_rows : public test_case {
}
void initialize_tensors(ggml_context * ctx) override {
- std::random_device rd;
- std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) {
continue;
}
- for (int i2 = 0; i2 < t->ne[2]; i2++) {
- for (int i1 = 0; i1 < t->ne[1]; i1++) {
- // generate a shuffled subset of row indices
- std::vector<int64_t> data(ne[1]);
- for (int i = 0; i < ne[1]; i++) {
- data[i] = i;
- }
- std::shuffle(data.begin(), data.end(), rng);
- data.resize(t->ne[0]);
-
- const size_t offs = i1*t->nb[1] + i2*t->nb[2];
- if (t->type == GGML_TYPE_I32) {
- // TODO: Make a template or something
- std::vector<int32_t> data_i32(t->ne[0]);
- for (int i = 0; i < t->ne[0]; i++) {
- data_i32[i] = static_cast<int32_t>(data[i]);
- }
- ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
- } else {
- ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
- }
- }
- }
+ init_set_rows_row_ids(t, ne[1]);
} else {
init_tensor_uniform(t);
}
@@ -2207,6 +2211,67 @@ struct test_set_rows : public test_case {
}
};
+// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS
+struct test_rope_set_rows : public test_case {
+ const ggml_type type;
+ const ggml_type type_idx;
+ const std::array<int64_t, 4> ne;
+ int mode;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, type_idx, ne, mode);
+ }
+
+ std::string op_desc(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ return "ROPE_SET_ROWS";
+ }
+
+ bool run_whole_graph() override { return true; }
+
+ test_rope_set_rows(ggml_type type,
+ ggml_type type_idx,
+ std::array<int64_t, 4> ne,
+ int mode)
+ : type(type), type_idx(type_idx), ne(ne), mode(mode) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+ ggml_set_name(src, "src");
+
+ ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+ ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode);
+
+ ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+ ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+ ggml_set_name(dst, "dst");
+
+ ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1);
+ ggml_set_name(row_idxs, "row_idxs");
+
+ ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) {
+ continue;
+ }
+
+ init_set_rows_row_ids(t, ne[2]);
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
@@ -6008,6 +6073,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
+ for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) {
+ for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode));
+ test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode));
+ }
+ }
+
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index db92a7901..e959674d1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}
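
The fix is the usual grid-stride pattern: the dispatch clamps the y dimension to maxComputeWorkGroupCount[1], and each workgroup then loops over rows with a stride of gl_WorkGroupSize.y * gl_NumWorkGroups.y, so row counts beyond the device limit are still covered. The loop is duplicated in both branches of main, presumably so needs_bounds_check stays a compile-time constant at each call site. A host-side sketch of the same structure, with an illustrative device limit:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t nrows = 200000;
        const uint32_t max_groups_y = 65535;  // device limit (illustrative)
        const uint32_t dispatched = std::min(nrows, max_groups_y);

        // What a single workgroup does inside the shader:
        const uint32_t group_id = 1234;       // gl_WorkGroupID.y
        uint32_t sorted = 0;
        for (uint32_t row = group_id; row < nrows; row += dispatched) {
            ++sorted;                          // argsort(needs_bounds_check, row)
        }
        printf("groups=%u, rows handled by this group=%u\n", dispatched, sorted);
        return 0;
    }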
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +-
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
tests/test-backend-ops.cpp | 3 +++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
+#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ // gpt-oss issue with Vulkan mmq_id
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
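
The overrun comes from the shared-memory caches being sized as buf_a[BM * BK_STEP]; the mul_mat_id variant needs additional shared arrays on top of that, so forcing BK_STEP to 1 under MUL_MAT_ID divides that cache's footprint by four and brings the total back under the device limit. The metal hunk fixes a related lookup problem: the pipeline-cache name now includes ne02, so mul_mm_id_map0 pipelines specialized for different ne02 values are no longer conflated. Illustrative shared-memory arithmetic (per-block size assumed, not taken from the shader):

    #include <cstdio>

    int main() {
        const int BM = 64;               // rows cached per workgroup (assumed)
        const int bytes_per_block = 36;  // e.g. qs[8] int32 plus a dm scale (assumed)
        const int bk_steps[2] = {4, 1};
        for (int bk_step : bk_steps) {
            printf("BK_STEP=%d -> buf_a = %d bytes\n",
                   bk_step, BM * bk_step * bytes_per_block);
        }
        return 0;
    }

The new MXFP4 mul_mat_id test case reproduces the gpt-oss failure that motivated the fix.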
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
supported (#16796)
* Experiment with a crash fix
* Add an assert for aborting and fix a comment
* Change to check whether a pipeline is empty
* Move the function into the class definition
* Replace the manual null checks with is_empty
* Modify is_empty to check only unaligned pipelines
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e959674d1..903050b0b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
+ // Returns true when all unaligned pipelines are null.
+ // We only check for unaligned variants since one of the unaligned pipelines must exist
+ // while aligned pipelines are optional
+ bool is_empty() const {
+ return l == nullptr && m == nullptr && s == nullptr;
+ }
};
-
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+ bool support_fp16acc = !mmp.f16acc->is_empty();
+ bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ return mmp.f32acc;
}
}
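
The underlying crash was a null dereference on devices where one accumulator flavor is never compiled: the old checks compared the vk_matmul_pipeline shared_ptr itself against nullptr, but that pointer can be allocated even when none of the pipelines inside it were built. is_empty() therefore inspects the unaligned l/m/s variants, since at least one of those must exist for a usable pipeline while the aligned a_* ones are optional. A condensed sketch of the corrected selection logic:

    #include <memory>

    struct Pipeline {};
    struct MatmulPipelines {
        std::shared_ptr<Pipeline> l, m, s;        // unaligned variants
        std::shared_ptr<Pipeline> a_l, a_m, a_s;  // aligned variants (optional)
        bool is_empty() const { return !l && !m && !s; }
    };
    struct MatmulPipelines2 { MatmulPipelines f16acc, f32acc; };

    // Prefer fp16 accumulation when the device supports it and those pipelines
    // exist; otherwise fall back to fp32, which must then be present.
    const MatmulPipelines * pick(const MatmulPipelines2 & mmp, bool device_fp16) {
        const bool has16 = !mmp.f16acc.is_empty();
        const bool has32 = !mmp.f32acc.is_empty();
        if (has16 && (device_fp16 || !has32)) {
            return &mmp.f16acc;
        }
        return has32 ? &mmp.f32acc : nullptr;  // the real code asserts has32
    }

    int main() {
        MatmulPipelines2 mmp;  // nothing compiled: pick() now degrades cleanly
        return pick(mmp, true) == nullptr ? 0 : 1;
    }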