/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "rope.cuh"

// Pair of floats passed by value to the RoPE kernels.
// rope_yarn() forwards v[0] and v[1] to rope_yarn_ramp() as the low and high bounds of the ramp;
// the host fills the array via ggml_rope_yarn_corr_dims().
struct rope_corr_dims {
    float v[2];
};

33
34
35
36
37

// Four integer section sizes passed by value to the multi-section RoPE kernels.
// rope_multi() uses all four entries (their sum is the sector period); rope_vision() reads only
// v[0] and v[1]. The host copies the values out of the op parameters.
struct mrope_sections {
    int v[4];
};

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// Linear ramp over the pair index i0/2: evaluates (i0/2 - low) / (high - low), clamps it to [0, 1]
// and returns one minus that value. The denominator is kept at or above 0.001 so that a degenerate
// range (high <= low) cannot divide by zero.
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float span     = max(0.001f, high - low);
    const float progress = (i0 / 2 - low) / span;
    const float clamped  = min(1.0f, max(0.0f, progress));
    return 1.0f - clamped;
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Produces the scaled cosine and sine of the rotation angle for one element pair.
//   theta_extrap - angle before frequency scaling
//   freq_scale   - factor applied to theta_extrap to obtain the interpolated angle
//   corr_dims    - ramp bounds handed to rope_yarn_ramp()
//   i0           - element index of the pair (forwarded to rope_yarn_ramp())
//   ext_factor   - weight of the ramp; 0 disables the mixing and the magnitude correction
//   mscale       - magnitude multiplier applied to both outputs
//   cos_theta, sin_theta - outputs
static __device__ void rope_yarn(
    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
    float * cos_theta, float * sin_theta) {
    // Interpolated angle; used as-is unless the extrapolation mix is enabled.
    const float theta_interp = freq_scale * theta_extrap;

    float angle = theta_interp;
    if (ext_factor != 0.0f) {
        // Blend interpolated and extrapolated angles according to the per-dimension ramp.
        const float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
        angle = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Magnitude correction that accompanies the interpolation.
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }

    *cos_theta = cosf(angle) * mscale;
    *sin_theta = sinf(angle) * mscale;
}

// RoPE kernel that rotates adjacent element pairs (i0, i0 + 1) within each row.
//
// Thread mapping: the y index (blockDim.y*blockIdx.y + threadIdx.y) selects the pair, so
// i0 = 2 * that index; the x index (blockDim.x*blockIdx.x + threadIdx.x) selects the row.
// There is no bounds check on the row, so the launch must not create more x threads than rows.
//
//   x, dst        - input / output buffers, addressed as row*ne0 + element
//   ne0           - number of elements per row
//   n_dims        - elements with i0 >= n_dims are copied through unrotated
//   pos           - positions; the entry used for a row is pos[row / p_delta_rows]
//   freq_scale, ext_factor, attn_factor, corr_dims - forwarded to rope_yarn()
//   theta_scale   - base of the per-pair angle: angle = pos * theta_scale^(i0/2)
//   freq_factors  - per-pair divisors of the angle; only read when has_ff is true
template<typename T, bool has_ff>
static __global__ void rope_norm(
    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    // The grid is rounded up along y, so surplus threads exit here.
    if (i0 >= ne0) {
        return;
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;

    // Elements beyond the rotated range are passed through untouched.
    if (i0 >= n_dims) {
        const int i = row*ne0 + i0;

        dst[i + 0] = x[i + 0];
        dst[i + 1] = x[i + 1];

        return;
    }

    const int i  = row*ne0 + i0;
    const int i2 = row/p_delta_rows; // index into pos for this row

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    // Read both inputs before writing so that x and dst may alias.
    const float x0 = x[i + 0];
    const float x1 = x[i + 1];

    dst[i + 0] = x0*cos_theta - x1*sin_theta;
    dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

// RoPE kernel that rotates the pair (i0/2, i0/2 + n_dims/2) within each row, i.e. the two
// elements of a pair are n_dims/2 apart instead of adjacent.
//
// Thread mapping: the y index selects the pair (i0 = 2 * y index), the x index selects the row.
// Elements with i0 >= n_dims are copied through unrotated, two adjacent elements per thread.
// The position used for a row is pos[row / p_delta_rows]; freq_factors is only read when has_ff
// is true.
template<typename T, bool has_ff>
static __global__ void rope_neox(
    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return; // surplus thread from the rounded-up grid
    }

    const int row      = blockDim.x*blockIdx.x + threadIdx.x;
    const int row_base = row*ne0;

    if (i0 >= n_dims) {
        // Outside the rotated range: plain copy of the adjacent pair.
        const int src = row_base + i0;
        dst[src + 0] = x[src + 0];
        dst[src + 1] = x[src + 1];
        return;
    }

    const int lo    = row_base + i0/2;   // first element of the pair
    const int hi    = lo + n_dims/2;     // second element of the pair
    const int token = row/p_delta_rows;  // index into pos for this row

    const float theta_base  = pos[token]*powf(theta_scale, i0/2.0f);
    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[lo];
    const float x1 = x[hi];

    dst[lo] = x0*cos_theta - x1*sin_theta;
    dst[hi] = x0*sin_theta + x1*cos_theta;
}

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
// Multi-section RoPE kernel. Pairs are laid out as in rope_neox (elements i0/2 and
// i0/2 + n_dims/2), but the position used for a pair depends on which section its pair index
// falls into: the pair index is reduced modulo the sum of all four section sizes and compared
// against the running section boundaries. Section k reads its position from pos[i2 + ne2*k],
// where i2 = row / p_delta_rows.
//
// Elements with i0 >= n_dims are copied through unrotated. freq_factors is only read when
// has_ff is true.
template<typename T, bool has_ff>
static __global__ void rope_multi(
    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return; // surplus thread from the rounded-up grid
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;

    if (i0 >= n_dims) {
        // Outside the rotated range: plain copy of the adjacent pair.
        const int src = row*ne0 + i0;
        dst[src + 0] = x[src + 0];
        dst[src + 1] = x[src + 1];
        return;
    }

    const int lo    = row*ne0 + i0/2;
    const int hi    = lo + n_dims/2;
    const int token = row/p_delta_rows;

    const int period = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int end1   = sections.v[1] + sections.v[0]; // end of section 1
    const int sector = (i0 / 2) % period;

    // Pick the section whose range contains this sector; anything past section 2 maps to section 3.
    int section;
    if (sector < sections.v[0]) {
        section = 0;
    } else if (sector < end1) {
        section = 1;
    } else if (sector < end1 + sections.v[2]) {
        section = 2;
    } else {
        section = 3;
    }

    const float theta_base  = pos[token + ne2 * section]*powf(theta_scale, i0/2.0f);
    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[lo];
    const float x1 = x[hi];

    dst[lo] = x0*cos_theta - x1*sin_theta;
    dst[hi] = x0*sin_theta + x1*cos_theta;
}

// Two-section RoPE kernel. Each thread rotates the pair (i0/2, i0/2 + n_dims) of its row; unlike
// the other kernels there is no pass-through branch for i0 >= n_dims.
//
// The pair index is reduced modulo sections.v[0] + sections.v[1]. Sectors in the first section use
// pos[i2] and sectors in the second use pos[i2 + ne2], with i2 = row / p_delta_rows. The exponent
// of theta_scale is the sector's offset within its own section. If neither range matches, the
// angle stays 0. freq_factors is only read when has_ff is true.
template<typename T, bool has_ff>
static __global__ void rope_vision(
    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return; // surplus thread from the rounded-up grid
    }

    const int row   = blockDim.x*blockIdx.x + threadIdx.x;
    const int lo    = row*ne0 + i0/2;
    const int hi    = lo + n_dims;
    const int token = row/p_delta_rows; // i2-th token

    const int period = sections.v[0] + sections.v[1];
    const int end1   = sections.v[1] + sections.v[0];
    const int sector = (i0 / 2) % period;

    float theta_base = 0.0f;
    if (sector < sections.v[0]) {
        theta_base = pos[token]*powf(theta_scale, sector);
    } else if (sector < end1) {
        theta_base = pos[token + ne2]*powf(theta_scale, sector - sections.v[0]);
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[lo];
    const float x1 = x[hi];

    dst[lo] = x0*cos_theta - x1*sin_theta;
    dst[hi] = x0*sin_theta + x1*cos_theta;
}

241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
// Host launcher for rope_norm. Requires an even ne0.
// Launch shape: blocks of (1, CUDA_ROPE_BLOCK_SIZE, 1) threads on a grid of
// (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1), i.e. one grid column per row and one thread per
// element pair. theta_scale is derived as freq_base^(-2/n_dims). The has_ff template flag is
// chosen from whether freq_factors is null. The launch is enqueued on `stream`.
template<typename T>
static void rope_norm_cuda(
    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const int  n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 threads(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 grid(nr, n_blocks_y, 1);

    if (freq_factors != nullptr) {
        rope_norm<T, true><<<grid, threads, 0, stream>>>(
                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors
                );
    } else {
        rope_norm<T, false><<<grid, threads, 0, stream>>>(
                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors
                );
    }
}

// Host launcher for rope_neox. Requires an even ne0.
// Launch shape: blocks of (1, CUDA_ROPE_BLOCK_SIZE, 1) threads on a grid of
// (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1). theta_scale is derived as freq_base^(-2/n_dims).
// The has_ff template flag is chosen from whether freq_factors is null. The launch is enqueued
// on `stream`.
template<typename T>
static void rope_neox_cuda(
    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const int  n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 threads(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 grid(nr, n_blocks_y, 1);

    if (freq_factors != nullptr) {
        rope_neox<T, true><<<grid, threads, 0, stream>>>(
                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors
                );
    } else {
        rope_neox<T, false><<<grid, threads, 0, stream>>>(
                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors
                );
    }
}

289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
// Host launcher for rope_multi. Requires an even ne0.
// Launch shape: blocks of (1, CUDA_ROPE_BLOCK_SIZE, 1) threads on a grid of
// (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1). theta_scale is derived as freq_base^(-2/n_dims).
// ne2 and sections are forwarded unchanged to the kernel. The has_ff template flag is chosen from
// whether freq_factors is null. The launch is enqueued on `stream`.
template<typename T>
static void rope_multi_cuda(
    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const int  n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 threads(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 grid(nr, n_blocks_y, 1);

    if (freq_factors != nullptr) {
        rope_multi<T, true><<<grid, threads, 0, stream>>>(
                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors, sections
                );
    } else {
        rope_multi<T, false><<<grid, threads, 0, stream>>>(
                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors, sections
                );
    }
}

// Host launcher for rope_vision. Requires an even ne0.
// theta_scale is derived as freq_base^(-2/n_dims). ne2 and sections are forwarded unchanged to
// the kernel. The has_ff template flag is chosen from whether freq_factors is null. The launch is
// enqueued on `stream`.
template<typename T>
static void rope_vision_cuda(
    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
    const int  n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 threads(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 grid(nr, n_blocks_y, 1);

    if (freq_factors != nullptr) {
        rope_vision<T, true><<<grid, threads, 0, stream>>>(
                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors, sections
                );
    } else {
        rope_vision<T, false><<<grid, threads, 0, stream>>>(
                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, freq_factors, sections
                );
    }
}

339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
// half-precision entry point: forwards every argument unchanged to rope_norm_cuda<half>.
static void rope_norm_cuda_f16(
    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    rope_norm_cuda<half>(
        x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

// single-precision entry point: forwards every argument unchanged to rope_norm_cuda<float>.
static void rope_norm_cuda_f32(
    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    rope_norm_cuda<float>(
        x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

// half-precision entry point: forwards every argument unchanged to rope_neox_cuda<half>.
static void rope_neox_cuda_f16(
    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    rope_neox_cuda<half>(
        x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

// single-precision entry point: forwards every argument unchanged to rope_neox_cuda<float>.
static void rope_neox_cuda_f32(
    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    rope_neox_cuda<float>(
        x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
// half-precision entry point: forwards every argument unchanged to rope_multi_cuda<half>.
static void rope_multi_cuda_f16(
    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    rope_multi_cuda<half>(
        x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

// single-precision entry point: forwards every argument unchanged to rope_multi_cuda<float>.
static void rope_multi_cuda_f32(
    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    rope_multi_cuda<float>(
        x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

// half-precision entry point: forwards every argument unchanged to rope_vision_cuda<half>.
static void rope_vision_cuda_f16(
    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    rope_vision_cuda<half>(
        x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

// single-precision entry point: forwards every argument unchanged to rope_vision_cuda<float>.
static void rope_vision_cuda_f32(
    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
    rope_vision_cuda<float>(
        x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows,
        freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
// CUDA entry point for the RoPE op. Decodes the parameters stored in dst->op_params and launches
// the kernel variant that matches the rope mode and the element type, on the context's stream.
//
//   dst->src[0] - input tensor; must be contiguous, F32 or F16, and of the same type as dst
//   dst->src[1] - positions; its data is read as int32
//   dst->src[2] - optional frequency factors; when present its data is read as float
//
// op_params layout as read here (32-bit slots): [1] n_dims, [2] mode, [4] n_ctx_orig,
// [5..10] freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow (floats),
// [11..14] the four mrope section sizes.
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    // Typed as float here; re-cast to half below when the tensors are F16.
    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads; passed to the kernels as the number of rows sharing one position
    const int64_t ne02 = src0->ne[2]; // stride between the per-section position blocks in pos (presumably the number of tokens - confirm)
    const int64_t nr = ggml_nrows(src0);

    //const int n_past     = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    // The float parameters live in int32 slots, so they are copied out bytewise.
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        // rope_vision pairs element i with element i + n_dims, so n_dims must be half a row.
        GGML_ASSERT(n_dims == ne00/2);
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_mrope && !is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
            rope_norm_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_norm_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    }
}