0033-vulkan-Handle-argsort-with-a-large-number-of-rows-16.patch 2.84 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp             |  4 ++++
 ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t nrows;
     int32_t order;
 };
 
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     case GGML_OP_ARGSORT:
         elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
         break;
     case GGML_OP_IM2COL:
         {
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     int32_t * op_params = (int32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
 
     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
+        nrows,
         op_params[0],
     }, dryrun);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1)          buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;
 
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }
 
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
 
     const uint row_offset = row * p.ncols;
 
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
 
 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }