0011-add-unpad-operator.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 17 Oct 2024 17:19:25 -0700
Subject: [PATCH] add unpad operator

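ggml_unpad is the inverse of ggml_pad: it trims p0, p1, p2 and p3
elements from the end of dimensions 0 through 3, so the result has
shape [ne0 - p0, ne1 - p1, ne2 - p2, ne3 - p3]. Only F32 tensors are
supported for now, the backward pass is not implemented, and the CUDA
path additionally assumes ne3 == 1.

A minimal usage sketch (the context `ctx` and tensor `t` are
illustrative, not part of this patch):

    // pad dim 0 of a [30, 4, 2, 1] F32 tensor up to 32, then trim it back
    struct ggml_tensor * padded  = ggml_pad  (ctx, t, 2, 0, 0, 0);
    struct ggml_tensor * trimmed = ggml_unpad(ctx, padded, 2, 0, 0, 0);
    // trimmed has the original shape again: [30, 4, 2, 1]
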
---
 ggml/include/ggml.h        | 10 ++++
 ggml/src/ggml-cuda.cu      |  4 ++
 ggml/src/ggml-cuda/pad.cu  | 46 +++++++++++++++++++
 ggml/src/ggml-cuda/pad.cuh |  1 +
 ggml/src/ggml-metal.m      | 33 ++++++++++++++
 ggml/src/ggml-metal.metal  | 45 ++++++++++++++++++
 ggml/src/ggml.c            | 93 +++++++++++++++++++++++++++++++++++++-
 7 files changed, 230 insertions(+), 2 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ce3d92cb..962cb5f7 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -506,6 +506,7 @@ extern "C" {
         GGML_OP_POOL_2D_BACK,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
+        GGML_OP_UNPAD,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1764,6 +1765,15 @@ extern "C" {
             int                  p2,
             int                  p3);
 
+    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
+    GGML_API struct ggml_tensor * ggml_unpad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index fe77b81c..6e84af56 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2270,6 +2270,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PAD:
             ggml_cuda_op_pad(ctx, dst);
             break;
+        case GGML_OP_UNPAD:
+            ggml_cuda_op_unpad(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_cuda_op_arange(ctx, dst);
             break;
@@ -2992,6 +2995,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index aba539e8..39fd4b16 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
+
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    }
+}
+
+static void unpad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    unpad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
index 8fd386b0..e2ededc3 100644
--- a/ggml/src/ggml-cuda/pad.cuh
+++ b/ggml/src/ggml-cuda/pad.cuh
@@ -3,3 +3,4 @@
 #define CUDA_PAD_BLOCK_SIZE 256
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 829c5e39..25702d85 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -193,6 +193,7 @@
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -689,6 +690,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
@@ -846,6 +848,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             return false;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
@@ -2655,6 +2658,36 @@ static void ggml_metal_encode_node(
 
                 const int nth = MIN(1024, ne0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_ARANGE:
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2b200032..09887511 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2029,6 +2029,51 @@ kernel void kernel_pad_f32(
     }
 }
 
+kernel void kernel_unpad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            }
+        }
+
+        return;
+    }
+}
+
 kernel void kernel_arange_f32(
     device        char * dst,
     constant   int64_t & ne0,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bcbc32d9..f4864ac8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2997,6 +2997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_2D_BACK",
     "UPSCALE",
     "PAD",
+    "UNPAD",
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3091,6 +3092,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_2d_back(x)",
     "upscale(x)",
     "pad(x)",
+    "unpad(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6955,6 +6957,32 @@ struct ggml_tensor * ggml_pad(
     return result;
 }
 
+// ggml_unpad
+
+struct ggml_tensor * ggml_unpad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] - p0,
+            a->ne[1] - p1,
+            a->ne[2] - p2,
+            a->ne[3] - p3);
+
+    result->op = GGML_OP_UNPAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(
@@ -15312,6 +15340,58 @@ static void ggml_compute_forward_pad(
     }
 }
 
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 
 // ggml_compute_forward_arange
 
@@ -17294,6 +17374,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad(params, tensor);
             } break;
+        case GGML_OP_UNPAD:
+            {
+                ggml_compute_forward_unpad(params, tensor);
+            } break;
         case GGML_OP_ARANGE:
             {
                 ggml_compute_forward_arange(params, tensor);
@@ -18369,6 +18453,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_OP_UNPAD:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_ARANGE:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19165,6 +19253,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT: