ggml.go 25 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
package ggml

3
4
5
6
7
8
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
Michael Yang's avatar
Michael Yang committed
9
10
11
import "C"

import (
12
	"context"
Michael Yang's avatar
Michael Yang committed
13
14
15
	"fmt"
	"io"
	"log/slog"
16
	"maps"
Michael Yang's avatar
Michael Yang committed
17
	"os"
18
	"runtime"
19
20
21
	"slices"
	"strconv"
	"strings"
22
	"sync/atomic"
23
	"unicode"
Michael Yang's avatar
Michael Yang committed
24
25
26
27
28
	"unsafe"

	"github.com/ollama/ollama/format"
	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
29
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
Michael Yang's avatar
Michael Yang committed
30
31
32
	"golang.org/x/sync/errgroup"
)

Michael Yang's avatar
Michael Yang committed
33
34
35
36
37
func devices() []*C.struct_ggml_backend_device {
	ggml.OnceLoad()
	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
	for i := range ds {
		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
Michael Yang's avatar
Michael Yang committed
38
	}
Michael Yang's avatar
Michael Yang committed
39
40

	return ds
41
}
Michael Yang's avatar
Michael Yang committed
42
43

type Backend struct {
44
45
46
	meta    *fs.GGML
	sched   *C.struct_ggml_backend_sched
	tensors map[string]*C.struct_ggml_tensor
Michael Yang's avatar
Michael Yang committed
47
48

	// input is the backend used for inputs
49
	input *C.struct_ggml_backend_buffer_type
Michael Yang's avatar
Michael Yang committed
50
51

	// layers is the backend used for repeating layers
52
	layers map[int]*C.struct_ggml_backend_buffer_type
53

54
	flashAttention bool
Michael Yang's avatar
Michael Yang committed
55
56
57

	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
58
59
}

60
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
Michael Yang's avatar
Michael Yang committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
	meta, n, err := fs.Decode(r, -1)
	if err != nil {
		return nil, err
	}

	slog.Info(
		"",
		"architecture", meta.KV().Architecture(),
		"file_type", meta.KV().FileType(),
		"name", meta.KV().String("general.name"),
		"description", meta.KV().String("general.description"),
		"num_tensors", len(meta.Tensors().Items()),
		"num_key_values", len(meta.KV()),
	)

76
	type deviceBufferType struct {
77
78
79
80
81
		d   *C.struct_ggml_backend_device
		bts []*C.struct_ggml_backend_buffer_type
	}

	var cpus, accels, gpus []*C.struct_ggml_backend_device
Michael Yang's avatar
Michael Yang committed
82
	for _, d := range devices() {
83
84
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
85
86
87
88
			if len(cpus) == 0 {
				// only the first cpu device should be used
				cpus = append(cpus, d)
			}
89
90
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
Michael Yang's avatar
Michael Yang committed
91
		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
92
			gpus = append(gpus, d)
Michael Yang's avatar
Michael Yang committed
93
94
95
		}
	}

Michael Yang's avatar
Michael Yang committed
96
	// create list of buffer types for the cpu
Michael Yang's avatar
Michael Yang committed
97
	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
98
99
100
101
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
Michael Yang's avatar
Michael Yang committed
102
			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
Michael Yang's avatar
Michael Yang committed
103
		}
104
105
	}

Michael Yang's avatar
Michael Yang committed
106
	// create list of buffer types for each gpu
107
	var gpuDeviceBufferTypes []deviceBufferType
108
109
	for _, d := range gpus {
		bt := C.ggml_backend_dev_buffer_type(d)
110
		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
111
			d:   d,
Michael Yang's avatar
Michael Yang committed
112
			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
113
		})
Michael Yang's avatar
Michael Yang committed
114
115
	}

Michael Yang's avatar
Michael Yang committed
116
117
118
119
120
	useDefaultSplit := true
	for _, s := range params.TensorSplit {
		if s != 0 {
			useDefaultSplit = false
			break
121
		}
Michael Yang's avatar
Michael Yang committed
122
	}
123

Michael Yang's avatar
Michael Yang committed
124
125
126
127
	// calculate splits
	splits := make([]float32, len(gpus))
	if useDefaultSplit {
		// default: split on free memory
128
129
130
131
132
		for i := range splits {
			var free, total C.size_t
			C.ggml_backend_dev_memory(gpus[i], &free, &total)
			splits[i] = float32(free)
		}
Michael Yang's avatar
Michael Yang committed
133
134
	} else {
		splits = params.TensorSplit
135
136
137
	}

	var sum float32
Michael Yang's avatar
Michael Yang committed
138
	// cumulative sum of all splits
139
140
141
142
143
	for i := range splits {
		sum += splits[i]
		splits[i] = sum
	}

Michael Yang's avatar
Michael Yang committed
144
	// normalize splits
145
	for i := range splits {
146
		splits[i] /= sum
147
148
	}

Michael Yang's avatar
Michael Yang committed
149
	// inputs always use cpu
Michael Yang's avatar
Michael Yang committed
150
	input := cpuDeviceBufferType
151

152
	blocks := int(meta.KV().BlockCount())
Michael Yang's avatar
Michael Yang committed
153
154
155
156

	// define a range of gpu layers. anything outside of this range is assigned to the cpu
	gpuRangeStart := max(0, blocks-params.NumGPULayers)
	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
Michael Yang's avatar
Michael Yang committed
157
	assignLayer := func(i int) deviceBufferType {
Michael Yang's avatar
Michael Yang committed
158
		if i < gpuRangeStart || i >= gpuRangeStop {
Michael Yang's avatar
Michael Yang committed
159
			return cpuDeviceBufferType
160
		}
161

Michael Yang's avatar
Michael Yang committed
162
		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
163
		if index < 0 || index >= len(gpuDeviceBufferTypes) {
Michael Yang's avatar
Michael Yang committed
164
			return cpuDeviceBufferType
165
166
167
		}

		return gpuDeviceBufferTypes[index]
168
169
	}

Michael Yang's avatar
Michael Yang committed
170
	// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
171
	layers := make([]deviceBufferType, blocks)
172
	for i := range layers {
173
		layers[i] = assignLayer(i)
174
175
	}

Michael Yang's avatar
Michael Yang committed
176
	// outputs are assigned iff allowed by splits and configured number of gpu layers
177
	output := assignLayer(blocks)
178
179
180

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
Michael Yang's avatar
Michael Yang committed
181
	// each layer has at most 2 extra tensors for rope operations
182
183
	maxTensors += blocks * 2

184
185
186
187
188
	type tensor struct {
		source *fs.Tensor
		target string
	}

Michael Yang's avatar
Michael Yang committed
189
	// some tensors are mapped to different names so keep a list
190
191
	targets := make(map[string][]string)

Michael Yang's avatar
Michael Yang committed
192
	// contexts are shared by tensors of the same buffer type
193
	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
194
	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
195
196
197
198
199
200
201
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}
Michael Yang's avatar
Michael Yang committed
202

203
204
205
206
207
208
209
210
			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
Michael Yang's avatar
Michael Yang committed
211
			defer C.free(unsafe.Pointer(cname))
212
213
214
215
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

216
			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
Michael Yang's avatar
Michael Yang committed
217
218
			C.ggml_set_name(tt, cname)

219
			slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
220
221
222
223
224
			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
Michael Yang's avatar
Michael Yang committed
225
226
	}

227
	contains := func(s string, parts ...string) bool {
228
229
230
231
232
233
234
235
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
Michael Yang's avatar
Michael Yang committed
236
237
	}

238
239
	for _, t := range meta.Tensors().Items() {
		switch {
240
		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
241
			createTensor(tensor{source: t}, input.bts)
Michael Yang's avatar
Michael Yang committed
242
243
244
			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
			}
245
		case contains(t.Name, "cls", "output", "output_norm"):
246
			createTensor(tensor{source: t}, output.bts)
247
		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
Michael Yang's avatar
Michael Yang committed
248
			// TODO: assign vision tensors to the gpu if possible
Michael Yang's avatar
Michael Yang committed
249
			createTensor(tensor{source: t}, output.bts)
Michael Yang's avatar
Michael Yang committed
250
251
252
253
254
255
256
257
		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
			// these tensors should be repeated per layer
			for i, layer := range layers {
				createTensor(tensor{
					source: t,
					target: "blk." + strconv.Itoa(i) + "." + t.Name,
				}, layer.bts)
			}
258
		default:
Michael Yang's avatar
Michael Yang committed
259
260
261
262
			layerIndex := -1
			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
				if i, err := strconv.Atoi(fields[0]); err == nil {
					layerIndex = i
263
				}
Michael Yang's avatar
Michael Yang committed
264
			}
265

Michael Yang's avatar
Michael Yang committed
266
267
			if layerIndex >= 0 {
				createTensor(tensor{source: t}, layers[layerIndex].bts)
268
			} else {
Michael Yang's avatar
Michael Yang committed
269
270
				// load all other tensors on the cpu
				createTensor(tensor{source: t}, input.bts)
271
272
273
			}
		}
	}
Michael Yang's avatar
Michael Yang committed
274

Michael Yang's avatar
Michael Yang committed
275
276
	// allocate buffers for each context
	bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
277
278
279
280
281
282
283
	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
Michael Yang's avatar
Michael Yang committed
284
		bbs[c] = b
285
286
287
	}

	for bs := range maps.Values(bbs) {
Michael Yang's avatar
Michael Yang committed
288
		slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
289
290
	}

Michael Yang's avatar
Michael Yang committed
291
	// map tensor names to tensors for easy lookup later
292
293
294
295
296
297
298
	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

299
300
301
302
303
	var doneBytes atomic.Uint64
	totalBytes := uint64(n) - meta.Tensors().Offset

	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))
304
	for _, t := range meta.Tensors().Items() {
305
306
307
308
		g.Go(func() error {
			tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
			for i := range tts {
				target := targets[t.Name][i]
309
310
311
				if target == "" {
					target = t.Name
				}
312

313
314
315
316
				tt, ok := tensors[target]
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}
Michael Yang's avatar
Michael Yang committed
317

318
319
320
321
322
323
324
325
326
327
328
				tts[i] = tt
			}

			sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
			bts := make([]byte, 128*format.KibiByte)

			var s uint64
			for s < t.Size() {
				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
				if err != nil {
					return err
329
				}
Michael Yang's avatar
Michael Yang committed
330

331
332
				for _, tt := range tts {
					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
333
				}
Michael Yang's avatar
Michael Yang committed
334

335
336
337
338
339
340
341
342
343
344
				s += uint64(n)

				if params.Progress != nil {
					done := doneBytes.Add(uint64(n))
					params.Progress(float32(done) / float32(totalBytes))
				}
			}

			return nil
		})
Michael Yang's avatar
Michael Yang committed
345
346
	}

347
348
349
350
351
352
353
354
	// start a goroutine to cancel the errgroup if the parent context is done
	go func() {
		<-ctx.Done()
		g.Go(func() error {
			return ctx.Err()
		})
	}()

355
	if err := g.Wait(); err != nil {
Michael Yang's avatar
Michael Yang committed
356
357
358
		return nil, err
	}

359
360
	// map devices to backend buffer types so new tensors can be assigned to the correct device
	deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
Michael Yang's avatar
Michael Yang committed
361
362
363
364

	// create backends and buffer types used for the compute graph scheduler
	var schedBackends []*C.struct_ggml_backend
	var schedBufts []*C.struct_ggml_backend_buffer_type
365
366
367
368
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := C.ggml_backend_dev_init(d, nil)
		bt := C.ggml_backend_get_default_buffer_type(b)
		if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
369
370
			// use the first gpu host buffer type for gpu if possible
			if hbt := C.ggml_backend_dev_host_buffer_type(gpus[0]); hbt != nil {
371
372
373
374
				bt = hbt
			}
		}

375
376
377
		deviceBufferTypes[d] = bt

		schedBackends = append(schedBackends, b)
Michael Yang's avatar
Michael Yang committed
378
		schedBufts = append(schedBufts, bt)
379

380
		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
381
382

		if C.ggml_backend_is_cpu(b) {
Michael Yang's avatar
Michael Yang committed
383
			// set number of threads for cpu backend
Michael Yang's avatar
Michael Yang committed
384
			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
385
		}
386
387
	}

Michael Yang's avatar
Michael Yang committed
388
	maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
Michael Yang's avatar
Michael Yang committed
389
	return &Backend{
390
		flashAttention: params.FlashAttention,
391
392
		meta:           meta,
		tensors:        tensors,
393
		sched: C.ggml_backend_sched_new(
Michael Yang's avatar
Michael Yang committed
394
395
396
397
			(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
			C.int(len(schedBackends)),
			C.size_t(maxGraphNodes),
398
			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
399
		),
400
		input: deviceBufferTypes[input.d],
401
402
		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
			m := make(map[int]*C.struct_ggml_backend_buffer_type)
403
			for i, layer := range layers {
404
				m[i] = deviceBufferTypes[layer.d]
405
406
407
			}
			return m
		}(),
Michael Yang's avatar
Michael Yang committed
408
		maxGraphNodes: maxGraphNodes,
Michael Yang's avatar
Michael Yang committed
409
410
411
412
413
414
415
416
417
418
419
420
	}, nil
}

// init makes New available through the ml backend registry under the name "ggml".
func init() {
	const backendName = "ggml"
	ml.RegisterBackend(backendName, New)
}

// Config exposes the model file's key/value metadata as the backend configuration.
func (b *Backend) Config() ml.Config {
	kv := b.meta.KV()
	return kv
}

func (b *Backend) Get(name string) ml.Tensor {
421
422
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
Michael Yang's avatar
Michael Yang committed
423
424
425
426
427
428
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
Michael Yang's avatar
Michael Yang committed
429
	return b.NewContextSize(b.maxGraphNodes)
430
431
432
}

func (b *Backend) NewContextSize(n int) ml.Context {
Jesse Gross's avatar
Jesse Gross committed
433
434
435
436
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}

Michael Yang's avatar
Michael Yang committed
437
	return &Context{
438
439
		b:             b,
		maxGraphNodes: n,
440
		ctx: C.ggml_init(C.struct_ggml_init_params{
441
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
442
443
			no_alloc: true,
		}),
Michael Yang's avatar
Michael Yang committed
444
445
446
	}
}

447
// CacheConfig reports the KV cache layout this backend expects. The layout
// depends on whether flash attention was enabled.
func (b *Backend) CacheConfig() ml.CacheConfig {
	if !b.flashAttention {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}

	return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
}

Michael Yang's avatar
Michael Yang committed
455
type Context struct {
456
	b *Backend
Michael Yang's avatar
Michael Yang committed
457

458
	ctx   *C.struct_ggml_context
Michael Yang's avatar
Michael Yang committed
459
	graph *C.struct_ggml_cgraph
460

461
462
	// buft is the buffer type used for new tensors
	buft *C.struct_ggml_backend_buffer_type
463

Michael Yang's avatar
Michael Yang committed
464
	// maxGraphNodes is the maximum allowed number of graph nodes in this context
465
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
466
467
}

Michael Yang's avatar
Michael Yang committed
468
469
func (c Context) Input() ml.Context {
	if c.b.input != nil {
470
471
472
		return &Context{
			b:             c.b,
			ctx:           c.ctx,
473
			buft:          c.b.input,
474
475
476
477
			maxGraphNodes: c.maxGraphNodes,
		}
	}

Michael Yang's avatar
Michael Yang committed
478
	return &c
479
480
}

Michael Yang's avatar
Michael Yang committed
481
func (c Context) Layer(i int) ml.Context {
482
	if buft, ok := c.b.layers[i]; ok {
483
484
485
		return &Context{
			b:             c.b,
			ctx:           c.ctx,
486
			buft:          buft,
487
488
489
490
			maxGraphNodes: c.maxGraphNodes,
		}
	}

Michael Yang's avatar
Michael Yang committed
491
	return &c
492
493
}

494
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
Michael Yang's avatar
Michael Yang committed
495
	if c.graph == nil {
496
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
Michael Yang's avatar
Michael Yang committed
497
498
	}

499
500
501
502
503
	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
Michael Yang's avatar
Michael Yang committed
504
505
}

Michael Yang's avatar
Michael Yang committed
506
func (c Context) Compute(tensors ...ml.Tensor) {
507
	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
Michael Yang's avatar
Michael Yang committed
508
	C.ggml_backend_sched_reset(c.b.sched)
Michael Yang's avatar
Michael Yang committed
509

510
511
512
	needSync := true
	sync := func() {
		if needSync {
513
			C.ggml_backend_sched_synchronize(c.b.sched)
514
515
516
			needSync = false
		}
	}
Michael Yang's avatar
Michael Yang committed
517

518
519
520
	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
521
522
		}
	}
Michael Yang's avatar
Michael Yang committed
523
524
}

Michael Yang's avatar
Michael Yang committed
525
func (c Context) MaxGraphNodes() int {
526
	return c.maxGraphNodes
Jesse Gross's avatar
Jesse Gross committed
527
528
}

529
530
531
// shapeToGGML converts shape to a C int64 array and returns a pointer to its
// first element. shape must be non-empty; an empty shape panics with an
// index-out-of-range error.
func shapeToGGML(shape []int) *C.int64_t {
	dims := make([]C.int64_t, 0, len(shape))
	for _, dim := range shape {
		dims = append(dims, C.int64_t(dim))
	}

	return &dims[0]
}

538
539
540
541
// pad rounds length up to the next multiple of pad. pad must be non-zero.
func pad(length, pad C.size_t) C.size_t {
	blocks := (length + pad - 1) / pad
	return blocks * pad
}

542
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
543
544
545
546
	if c.buft == nil {
		panic("set Input, Output, or Layer before creating tensors")
	}

Michael Yang's avatar
Michael Yang committed
547
548
549
550
551
552
	var cdtype uint32
	switch dtype {
	case ml.DTypeF32:
		cdtype = C.GGML_TYPE_F32
	case ml.DTypeF16:
		cdtype = C.GGML_TYPE_F16
553
554
555
556
	case ml.DTypeQ80:
		cdtype = C.GGML_TYPE_Q8_0
	case ml.DTypeQ40:
		cdtype = C.GGML_TYPE_Q4_0
Michael Yang's avatar
Michael Yang committed
557
558
559
560
561
562
	case ml.DTypeI32:
		cdtype = C.GGML_TYPE_I32
	default:
		panic("unsupported dtype")
	}

Jesse Gross's avatar
Jesse Gross committed
563
	if len(shape) < 1 || shape[0] == 0 {
Michael Yang's avatar
Michael Yang committed
564
565
566
		var shape C.int64_t = 0
		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
	} else if len(shape) > 4 {
Michael Yang's avatar
Michael Yang committed
567
568
569
570
571
572
573
574
575
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

Michael Yang's avatar
Michael Yang committed
576
	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
577
578
	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
Michael Yang's avatar
Michael Yang committed
579
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
580
	return &Tensor{b: c.b, t: t}
581
582
583
}

// Empty returns a new tensor of the given dtype and shape. Its contents are not
// initialized.
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	uninitialized := c.newTensor(dtype, shape)
	return uninitialized
}

func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
588
	t := c.newTensor(dtype, shape)
589
590
	C.ggml_set_zero(t.(*Tensor).t)
	return t
Michael Yang's avatar
Michael Yang committed
591
592
}

593
// checkShape verifies that s holds exactly as many elements as the product of
// shape.
//
// An empty s is always accepted; callers then create an empty tensor. Every
// dimension must be positive. Zero or negative dimensions are reported as an
// error, rather than panicking with a divide-by-zero or cancelling out (for
// example, shape [-1, -2] for two elements).
func checkShape[S ~[]E, E any](s S, shape ...int) error {
	n := len(s)

	if n == 0 {
		return nil
	}

	// Divide n by each dimension and require an exact quotient every time.
	// Truncating division would accept shapes that hold fewer elements than s
	// (e.g. 6 elements with shape [4]). Dividing instead of multiplying also
	// avoids overflowing an explicit product.
	for _, v := range shape {
		if v <= 0 || n%v != 0 {
			return fmt.Errorf("invalid shape: %v", shape)
		}
		n /= v
	}

	if n != 1 {
		return fmt.Errorf("invalid shape: %v", shape)
	}

	return nil
}

func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
612
	if err := checkShape(s, shape...); err != nil {
613
614
615
616
		return nil, err
	}

	t := c.newTensor(ml.DTypeF32, shape)
Jesse Gross's avatar
Jesse Gross committed
617
618
619
620
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

621
	return t, nil
Michael Yang's avatar
Michael Yang committed
622
623
624
}

func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
625
	if err := checkShape(s, shape...); err != nil {
626
627
628
629
		return nil, err
	}

	t := c.newTensor(ml.DTypeI32, shape)
Jesse Gross's avatar
Jesse Gross committed
630
631
632
633
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

634
	return t, nil
Michael Yang's avatar
Michael Yang committed
635
636
}

Michael Yang's avatar
Michael Yang committed
637
638
func (c *Context) Close() {
	if c != nil {
639
640
		C.ggml_free(c.ctx)
	}
Michael Yang's avatar
Michael Yang committed
641
642
643
}

type Tensor struct {
644
	b    *Backend
Michael Yang's avatar
Michael Yang committed
645
	t    *C.struct_ggml_tensor
646
	sync func()
Michael Yang's avatar
Michael Yang committed
647
648
649
650
651
652
653
654
655
656
}

// LogValue implements slog.LogValuer. It logs the tensor's name, ggml type
// name and shape as a group.
func (t *Tensor) LogValue() slog.Value {
	name := C.GoString(C.ggml_get_name(t.t))
	typeName := C.GoString(C.ggml_type_name(t.t._type))
	return slog.GroupValue(
		slog.String("name", name),
		slog.String("type", typeName),
		slog.Any("shape", t.Shape()),
	)
}

657
658
func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
Michael Yang's avatar
Michael Yang committed
659
660
}

661
662
func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
Michael Yang's avatar
Michael Yang committed
663
664
}

665
666
func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
Michael Yang's avatar
Michael Yang committed
667
668
669
670
671
672
673
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

674
675
676
677
678
679
680
681
682
func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
Michael Yang's avatar
Michael Yang committed
683
684
}

685
686
687
688
689
690
func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
Michael Yang's avatar
Michael Yang committed
691
692
693
694
695
696
697
698
699
	}

	return
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
Jesse Gross's avatar
Jesse Gross committed
700
701
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
702
703
704
705
	case C.GGML_TYPE_Q8_0:
		return ml.DTypeQ80
	case C.GGML_TYPE_Q4_0:
		return ml.DTypeQ40
Michael Yang's avatar
Michael Yang committed
706
707
708
709
710
711
712
713
714
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
		return ml.DTypeOther
	}
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
715
		b: t.b,
Michael Yang's avatar
Michael Yang committed
716
717
718
719
720
721
722
723
724
725
726
727
728
729
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

// Stack concatenates t with each tensor in s along dim, in order. With no
// tensors in s it returns t itself.
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) == 0 {
		return t
	}

	rest := s[0].Stack(ctx, dim, s[1:]...)
	return t.Concat(ctx, rest, dim)
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
730
		b: t.b,
Michael Yang's avatar
Michael Yang committed
731
732
733
734
735
736
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
	return &Tensor{
737
		b: t.b,
Michael Yang's avatar
Michael Yang committed
738
739
740
741
742
743
		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
744
		b: t.b,
Michael Yang's avatar
Michael Yang committed
745
746
747
748
749
750
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
751
		b: t.b,
Michael Yang's avatar
Michael Yang committed
752
753
754
755
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

756
757
758
759
760
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
761
		b: t.b,
762
763
764
765
		t: mul,
	}
}

Michael Yang's avatar
Michael Yang committed
766
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
767
	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
Michael Yang's avatar
Michael Yang committed
768
769
770
771
772
773
774
775
	if b != nil {
		tt = tt.Add(ctx, b)
	}

	return tt
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
776
	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
Michael Yang's avatar
Michael Yang committed
777
778
}

779
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
780
781
782
783
784
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
785
		b: t.b,
Michael Yang's avatar
Michael Yang committed
786
787
788
789
790
791
792
793
794
795
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
796
		b: t.b,
Michael Yang's avatar
Michael Yang committed
797
798
799
800
801
802
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
803
		b: t.b,
Michael Yang's avatar
Michael Yang committed
804
805
806
807
808
809
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
810
		b: t.b,
Michael Yang's avatar
Michael Yang committed
811
812
813
814
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

815
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
816
817
818
	switch len(shape) {
	case 1:
		return &Tensor{
819
			b: t.b,
Michael Yang's avatar
Michael Yang committed
820
821
822
823
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
824
			b: t.b,
Michael Yang's avatar
Michael Yang committed
825
826
827
828
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
829
			b: t.b,
Michael Yang's avatar
Michael Yang committed
830
831
832
833
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
834
			b: t.b,
Michael Yang's avatar
Michael Yang committed
835
836
837
838
839
840
841
842
843
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
844
		b: t.b,
Michael Yang's avatar
Michael Yang committed
845
846
847
848
849
850
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
851
		b: t.b,
Michael Yang's avatar
Michael Yang committed
852
853
854
855
856
857
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
858
		b: t.b,
Michael Yang's avatar
Michael Yang committed
859
860
861
862
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

863
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
864
865
866
867
868
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
869
		b: t.b,
Michael Yang's avatar
Michael Yang committed
870
871
872
873
874
875
876
877
		t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
878
			b: t.b,
Michael Yang's avatar
Michael Yang committed
879
880
881
882
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
883
			b: t.b,
Michael Yang's avatar
Michael Yang committed
884
885
886
887
888
889
890
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
891
			b: t.b,
Michael Yang's avatar
Michael Yang committed
892
893
894
895
896
897
898
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
899
			b: t.b,
Michael Yang's avatar
Michael Yang committed
900
901
902
903
904
905
906
907
908
909
910
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

const (
Patrick Devine's avatar
Patrick Devine committed
911
912
913
914
	ropeTypeNorm   C.int = 0
	ropeTypeNeox   C.int = 2
	ropeTypeMrope  C.int = 8
	ropeTypeVision C.int = 24
Michael Yang's avatar
Michael Yang committed
915
916
)

Patrick Devine's avatar
Patrick Devine committed
917
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
918
	if ropeFactors == nil {
919
		ropeFactors = &Tensor{b: t.b}
Michael Yang's avatar
Michael Yang committed
920
921
	}

Jesse Gross's avatar
Jesse Gross committed
922
923
924
925
926
	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

Michael Yang's avatar
Michael Yang committed
927
	return &Tensor{
928
		b: t.b,
Michael Yang's avatar
Michael Yang committed
929
		t: C.ggml_rope_ext(
Jesse Gross's avatar
Jesse Gross committed
930
			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
Michael Yang's avatar
Michael Yang committed
931
			C.int(ropeDim),
Patrick Devine's avatar
Patrick Devine committed
932
933
			C.int(ropeType),
			131072, // YaRN n_ctx_train
Michael Yang's avatar
Michael Yang committed
934
935
936
937
938
939
940
941
942
943
944
945
			C.float(ropeBase),
			C.float(ropeScale),
			0.,  // YaRN ext_factor
			1.,  // YaRN attn_factor
			32., // YaRN beta_fast
			1.,  // YaRN beta_slow
		),
	}
}

func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
	return &Tensor{
946
		b: t.b,
Michael Yang's avatar
Michael Yang committed
947
948
949
950
951
952
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
	return &Tensor{
953
		b: t.b,
Michael Yang's avatar
Michael Yang committed
954
955
956
957
958
959
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
960
		b: t.b,
Michael Yang's avatar
Michael Yang committed
961
962
963
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}
964

Michael Yang's avatar
Michael Yang committed
965
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
966
967
	return &Tensor{
		b: t.b,
Michael Yang's avatar
Michael Yang committed
968
		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
Michael Yang's avatar
Michael Yang committed
969
970
971
	}
}

Michael Yang's avatar
Michael Yang committed
972
973
974
975
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
Michael Yang's avatar
Michael Yang committed
976
		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
Michael Yang's avatar
Michael Yang committed
977
	case 1:
Michael Yang's avatar
Michael Yang committed
978
		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
Michael Yang's avatar
Michael Yang committed
979
980
981
982
983
984
985
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}

986
987
988
989
990
991
// ScaledDotProductAttention computes attention with t as the query against
// key and value. The logits are multiplied by scale and the optional mask is
// added before the softmax. mask may be nil.
//
// Before use, the query and key are permuted with (0, 2, 1, 3).
//   - When the backend has flash attention enabled, value is permuted the same
//     way and the whole computation is delegated to ggml_flash_attn_ext with
//     F32 precision.
//   - Otherwise, the path is: K·Q (full precision), then ggml_soft_max_ext,
//     then V·softmax. The result is permuted back with (0, 2, 1, 3) and made
//     contiguous.
//
// NOTE(review): the permutes assume a specific head/sequence axis layout for
// t, key and value; confirm the expected layout against the callers.
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

	if t.b.flashAttention {
		value = value.Permute(ctx, 0, 2, 1, 3)

		// The trailing 0, 0 are ggml's max_bias and logit_softcap arguments.
		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
		return &Tensor{b: t.b, t: kqv}
	}

	kq := key.MulmatFullPrec(ctx, query)
	kq = &Tensor{
		b: t.b,
		// The trailing 0 is ggml's max_bias argument (no ALiBi bias).
		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
	}

	kqv := value.Mulmat(ctx, kq)
	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}