From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 17 Oct 2024 17:19:25 -0700
Subject: [PATCH] add unpad operator
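
The new operator trims p0..p3 trailing elements from dims 0..3 of a
tensor, the inverse of ggml_pad, with F32 kernels for the CPU, CUDA
and Metal backends (the CUDA path additionally requires ne[3] == 1).

A rough usage sketch, illustrative only and not part of this patch
(context size and tensor shapes are arbitrary):

    #include "ggml.h"

    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // pad an F32 tensor, then trim the padding back off
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 4);
    struct ggml_tensor * p = ggml_pad  (ctx, a, 2, 1, 0, 0); // ne = [12, 5, 1, 1]
    struct ggml_tensor * u = ggml_unpad(ctx, p, 2, 1, 0, 0); // ne = [10, 4, 1, 1]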

---
 ggml/include/ggml.h                  | 10 +++++
 ggml/src/ggml-cpu/ggml-cpu.c         | 58 ++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/ggml-cuda.cu      |  4 ++
 ggml/src/ggml-cuda/pad.cu            | 46 ++++++++++++++++++++++
 ggml/src/ggml-cuda/pad.cuh           |  1 +
 ggml/src/ggml-metal/ggml-metal.m     | 33 ++++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal | 45 +++++++++++++++++++++
 ggml/src/ggml.c                      | 25 +++++++++++-
 8 files changed, 220 insertions(+), 2 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index c714fc8c..1bc50fca 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -499,6 +499,7 @@ extern "C" {
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
+        GGML_OP_UNPAD,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1735,6 +1736,15 @@ extern "C" {
             int                   p0,
             int                   p1);
 
+    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
+    GGML_API struct ggml_tensor * ggml_unpad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b7fefb9d..b307d554 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -10588,6 +10588,59 @@ static void ggml_compute_forward_pad_reflect_1d(
     }
 }
 
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_arange
 
 static void ggml_compute_forward_arange_f32(
@@ -12690,6 +12743,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad_reflect_1d(params, tensor);
             } break;
+        case GGML_OP_UNPAD:
+            {
+                ggml_compute_forward_unpad(params, tensor);
+            } break;
         case GGML_OP_ARANGE:
             {
                 ggml_compute_forward_arange(params, tensor);
@@ -13033,6 +13090,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index aaa79ea4..9286f866 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2082,6 +2082,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PAD:
             ggml_cuda_op_pad(ctx, dst);
             break;
+        case GGML_OP_UNPAD:
+            ggml_cuda_op_unpad(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_cuda_op_arange(ctx, dst);
             break;
@@ -3010,6 +3013,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index aba539e8..39fd4b16 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
+
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    }
+}
+
+static void unpad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    unpad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
index 8fd386b0..e2ededc3 100644
--- a/ggml/src/ggml-cuda/pad.cuh
+++ b/ggml/src/ggml-cuda/pad.cuh
@@ -3,3 +3,4 @@
 #define CUDA_PAD_BLOCK_SIZE 256
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index cd8ef741..318addec 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -311,6 +311,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
+    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -910,6 +911,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
@@ -1145,6 +1147,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
@@ -3348,6 +3351,36 @@ static void ggml_metal_encode_node(
 
                 const int nth = MIN(1024, ne0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_ARANGE:
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 8ba43904..204c93e6 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2944,6 +2944,51 @@ kernel void kernel_pad_reflect_1d_f32(
     }
 }
 
+kernel void kernel_unpad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            }
+        }
+
+        return;
+    }
+}
+
 kernel void kernel_arange_f32(
     device        char * dst,
     constant   int64_t & ne0,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2bbe5f48..7ffcd907 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -954,6 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "UPSCALE",
     "PAD",
     "PAD_REFLECT_1D",
+    "UNPAD",
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "upscale(x)",
     "pad(x)",
     "pad_reflect_1d(x)",
+    "unpad(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4214,6 +4216,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
     return result;
 }
 
+// ggml_unpad
+
+struct ggml_tensor * ggml_unpad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] - p0,
+            a->ne[1] - p1,
+            a->ne[2] - p2,
+            a->ne[3] - p3);
+
+    result->op = GGML_OP_UNPAD;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(