0022-metal-add-op_neg.patch 3.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Wed, 2 Apr 2025 15:26:15 -0700
Subject: [PATCH] metal: add op_neg

---
 ggml/src/ggml-metal/ggml-metal.m     | 15 +++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal |  7 +++++++
 2 files changed, 22 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index e4c093f9..d8422f1b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                           neg,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                        argmax,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_NEG:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f38909d0..bb0ff668 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -945,6 +945,13 @@ kernel void kernel_cos(
     dst[tpig] = cos(src0[tpig]);
 }
 
+kernel void kernel_neg(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = -src0[tpig];
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,