0032-interleave-multi-rope.patch 4.81 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Web, 16 Oct 2025 20:37:19 -0700
Subject: [PATCH] interleave multi rope

since ollama doesn't use mrope for anything else, change it to mean the
interleaved version used for qwen3vl
---
 ggml/src/ggml-cpu/ops.cpp                           |  7 ++-----
 ggml/src/ggml-cuda/rope.cu                          | 12 +++---------
 ggml/src/ggml-metal/ggml-metal.metal                | 10 +++-------
 ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp | 12 +++---------
 4 files changed, 11 insertions(+), 30 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 31478dd8e..4d1ed207e 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init(
         }
 
         float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
+        if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) {
             theta = theta_h;
         }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
+        else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) {
             theta = theta_w;
         }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
-        }
 
         rope_yarn(
             theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index d058504cd..287fe9d2c 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -151,19 +151,13 @@ static __global__ void rope_multi(
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
 
-    float theta_base = 0.0;
-    if (sector < sections.v[0]) {
-        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
+    float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+    if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) {
         theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
     }
-    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+    else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) {
         theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
     }
-    else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
-    }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 375a0c7fd..9866c96b4 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi(
             const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
             const int sector    = ic % sect_dims;
 
-            float theta_base;
-            if (sector < args.sect_0) {
-                theta_base = (float) pos[i2];
-            } else if (sector < sec_w01) {
+            float theta_base = (float) pos[i2];
+            if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) {
                 theta_base = (float) pos[i2 + args.ne02];
-            } else if (sector < sec_w012) {
+            } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) {
                 theta_base = (float) pos[i2 + args.ne02 * 2];
-            } else {
-                theta_base = (float) pos[i2 + args.ne02 * 3];
             }
             // end of mrope
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index 111286b49..6fc2b42f8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -31,19 +31,13 @@ void main() {
     const int sec_w = p.sections[1] + p.sections[0];
     const uint sector = (i0 / 2) % sect_dims;
 
-    float theta_base = 0.0;
-    if (sector < p.sections[0]) {
-        theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
-    }
-    else if (sector >= p.sections[0] && sector < sec_w) {
+    float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+    if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) {
         theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
     }
-    else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
+    else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) {
         theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
     }
-    else if (sector >= sec_w + p.sections[2]) {
-        theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
-    }
 
     const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;