0018-add-op_neg.patch 3.73 KB
Newer Older
1
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
2
3
4
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 8 Apr 2025 20:41:24 -0700
Subject: [PATCH] add op_neg
5

6
adds the neg operator to ggml
7
8
9
10
11
12
---
 ggml/src/ggml-metal/ggml-metal.m     | 15 +++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal |  7 +++++++
 2 files changed, 22 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
13
index b121ab9e..fea50521 100644
14
15
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
16
@@ -461,6 +461,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
17
18
19
20
21
22
23
     GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
24
25
26
27
28
29
30
31
32
@@ -1119,6 +1120,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
@@ -1280,6 +1282,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
33
34
35
36
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_NEG:
37
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
38
39
                 default:
                     return false;
40
@@ -1966,6 +1969,18 @@ static void ggml_metal_encode_node(
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
60
index e3185e5b..ede9d1e6 100644
61
62
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
63
@@ -949,6 +949,13 @@ kernel void kernel_cos(
64
65
66
67
68
69
70
71
72
73
74
75
76
     dst[tpig] = cos(src0[tpig]);
 }
 
+kernel void kernel_neg(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = -src0[tpig];
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,