ggml.go 21.1 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
package ggml

3
4
5
6
7
8
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
Michael Yang's avatar
Michael Yang committed
9
10
11
import "C"

import (
12
	"errors"
Michael Yang's avatar
Michael Yang committed
13
14
	"fmt"
	"io"
15
	"iter"
Michael Yang's avatar
Michael Yang committed
16
	"log/slog"
17
	"maps"
Michael Yang's avatar
Michael Yang committed
18
	"os"
19
20
21
22
	"slices"
	"strconv"
	"strings"
	"unicode"
Michael Yang's avatar
Michael Yang committed
23
24
25
26
27
	"unsafe"

	"github.com/ollama/ollama/format"
	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
28
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
Michael Yang's avatar
Michael Yang committed
29
30
31
	"golang.org/x/sync/errgroup"
)

32
33
func devices() iter.Seq[*C.struct_ggml_backend_device] {
	return func(yield func(*C.struct_ggml_backend_device) bool) {
34
		ggml.OnceLoad()
35
36
37
38
39
		for i := range C.ggml_backend_dev_count() {
			if !yield(C.ggml_backend_dev_get(i)) {
				return
			}
		}
Michael Yang's avatar
Michael Yang committed
40
	}
41
}
Michael Yang's avatar
Michael Yang committed
42
43

type Backend struct {
44
45
46
47
48
49
	meta    *fs.GGML
	sched   *C.struct_ggml_backend_sched
	tensors map[string]*C.struct_ggml_tensor
	input   *C.struct_ggml_backend
	output  *C.struct_ggml_backend
	layers  map[int]*C.struct_ggml_backend
50

51
	flashAttention bool
Michael Yang's avatar
Michael Yang committed
52
53
}

54
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
Michael Yang's avatar
Michael Yang committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
	meta, n, err := fs.Decode(r, -1)
	if err != nil {
		return nil, err
	}

	slog.Info(
		"",
		"architecture", meta.KV().Architecture(),
		"file_type", meta.KV().FileType(),
		"name", meta.KV().String("general.name"),
		"description", meta.KV().String("general.description"),
		"num_tensors", len(meta.Tensors().Items()),
		"num_key_values", len(meta.KV()),
	)

70
71
72
73
74
75
76
77
78
79
80
81
	type dbt struct {
		d   *C.struct_ggml_backend_device
		bts []*C.struct_ggml_backend_buffer_type
	}

	var cpus, accels, gpus []*C.struct_ggml_backend_device
	for d := range devices() {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
			cpus = append(cpus, d)
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
Michael Yang's avatar
Michael Yang committed
82
		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
83
			gpus = append(gpus, d)
Michael Yang's avatar
Michael Yang committed
84
85
86
		}
	}

87
88
89
90
91
92
	var cpuBufferTypes []*C.struct_ggml_backend_buffer_type
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			cpuBufferTypes = append(cpuBufferTypes, C.ggml_backend_dev_buffer_type(d))
Michael Yang's avatar
Michael Yang committed
93
		}
94
95
96
97
	}

	var sum uint64
	var cumsum []uint64
Michael Yang's avatar
Michael Yang committed
98

99
100
101
102
103
104
105
106
107
108
109
110
	var gpuBufferTypes []dbt
	for _, d := range gpus {
		var free, total C.size_t
		C.ggml_backend_dev_memory(d, &free, &total)
		sum += uint64(free)
		cumsum = append(cumsum, sum)

		bt := C.ggml_backend_dev_buffer_type(d)
		gpuBufferTypes = append(gpuBufferTypes, dbt{
			d:   d,
			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuBufferTypes...),
		})
Michael Yang's avatar
Michael Yang committed
111
112
	}

113
114
115
116
117
118
119
120
121
122
123
	splits := make([]float64, len(cumsum))
	for i := range splits {
		splits[i] = float64(cumsum[i]) / float64(sum)
	}

	input := dbt{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}

	var blocks int
	for key, value := range meta.KV() {
		if strings.HasSuffix(key, ".block_count") {
			blocks += int(value.(uint32))
Michael Yang's avatar
Michael Yang committed
124
		}
125
	}
Michael Yang's avatar
Michael Yang committed
126

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
	indexFunc := func(i int) func(float64) bool {
		return func(f float64) bool {
			return float64(i)/float64(blocks+1) < f
		}
	}

	layers := make([]dbt, blocks)
	for i := range layers {
		layers[i] = gpuBufferTypes[slices.IndexFunc(splits, indexFunc(i))]
	}

	output := gpuBufferTypes[slices.IndexFunc(splits, indexFunc(blocks))]

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
	maxTensors += blocks * 2

144
145
146
147
148
149
150
	type tensor struct {
		source *fs.Tensor
		target string
	}

	targets := make(map[string][]string)

151
	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
152
	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
153
154
155
156
157
158
159
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}
Michael Yang's avatar
Michael Yang committed
160

161
162
163
164
165
166
167
168
			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
Michael Yang's avatar
Michael Yang committed
169
			defer C.free(unsafe.Pointer(cname))
170
171
172
173
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

174
			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
Michael Yang's avatar
Michael Yang committed
175
176
			C.ggml_set_name(tt, cname)

177
			slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
178
179
180
181
182
			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
Michael Yang's avatar
Michael Yang committed
183
184
	}

185
186
187
188
189
190
191
192
193
	hasPart := func(s string, parts ...string) bool {
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
Michael Yang's avatar
Michael Yang committed
194
195
	}

196
197
198
	for _, t := range meta.Tensors().Items() {
		switch {
		case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
199
			createTensor(tensor{source: t}, input.bts)
200
		case hasPart(t.Name, "cls", "output", "output_norm"):
201
			createTensor(tensor{source: t}, output.bts)
202
203
204
205
206
207
208
209
210
211
		default:
			if i := func() int {
				if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
					if i, err := strconv.Atoi(fields[0]); err == nil {
						return i
					}
				}

				return -1
			}(); i >= 0 {
212
				createTensor(tensor{source: t}, layers[i].bts)
213
			} else {
214
215
216
217
218
				for i, layer := range layers {
					createTensor(tensor{
						source: t,
						target: "blk." + strconv.Itoa(i) + "." + t.Name,
					}, layer.bts)
219
220
221
222
				}
			}
		}
	}
Michael Yang's avatar
Michael Yang committed
223

224
225
226
227
228
229
230
231
232
233
234
235
236
237
	bbs := make(map[*C.struct_ggml_context][]*C.struct_ggml_backend_buffer, len(ctxs))

	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
		bbs[c] = append(bbs[c], b)
	}

	for bs := range maps.Values(bbs) {
		for _, b := range bs {
238
			slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
239
240
241
242
243
244
245
246
247
248
249
		}
	}

	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
Michael Yang's avatar
Michael Yang committed
250
	var g errgroup.Group
251
	for _, t := range meta.Tensors().Items() {
252
253
254
255
256
		for _, target := range targets[t.Name] {
			g.Go(func() error {
				if target == "" {
					target = t.Name
				}
257

258
259
260
261
				tt, ok := tensors[target]
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}
Michael Yang's avatar
Michael Yang committed
262

263
264
265
266
267
				bts := make([]byte, t.Size())
				n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
				if err != nil {
					return err
				}
Michael Yang's avatar
Michael Yang committed
268

269
270
271
				if n != len(bts) {
					return errors.New("short read")
				}
Michael Yang's avatar
Michael Yang committed
272

273
274
275
276
277
278
279
				cname := C.CString(t.Name)
				C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
				C.free(unsafe.Pointer(cname))

				return nil
			})
		}
Michael Yang's avatar
Michael Yang committed
280
281
	}

282
	if g.Wait() != nil {
Michael Yang's avatar
Michael Yang committed
283
284
285
		return nil, err
	}

286
	deviceBackends := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend)
287
288
289
290
291
	var backends []*C.struct_ggml_backend
	var bufts []*C.struct_ggml_backend_buffer_type
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := C.ggml_backend_dev_init(d, nil)
		backends = append(backends, b)
292
		deviceBackends[d] = b
293
294
295
296
297
298
299
300
301
302

		bt := C.ggml_backend_get_default_buffer_type(b)
		if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
			if hbt := C.ggml_backend_dev_host_buffer_type(d); hbt != nil {
				bt = hbt
			}
		}

		bufts = append(bufts, bt)

303
		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
304
305
	}

Michael Yang's avatar
Michael Yang committed
306
	return &Backend{
307
		flashAttention: params.FlashAttention,
308
309
		meta:           meta,
		tensors:        tensors,
310
311
312
313
314
315
316
		sched: C.ggml_backend_sched_new(
			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
			C.int(len(backends)),
			C.size_t(max(8192, len(meta.Tensors().Items())*5)),
			true,
		),
317
318
319
320
321
322
323
324
325
		input:  deviceBackends[input.d],
		output: deviceBackends[output.d],
		layers: func() map[int]*C.struct_ggml_backend {
			m := make(map[int]*C.struct_ggml_backend)
			for i, layer := range layers {
				m[i] = deviceBackends[layer.d]
			}
			return m
		}(),
Michael Yang's avatar
Michael Yang committed
326
327
328
329
330
331
332
333
334
335
336
337
	}, nil
}

func init() {
	ml.RegisterBackend("ggml", New)
}

func (b *Backend) Config() ml.Config {
	return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
338
339
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
Michael Yang's avatar
Michael Yang committed
340
341
342
343
344
345
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
346
347
348
349
	return b.NewContextSize(max(8192, len(b.meta.Tensors().Items())*5))
}

func (b *Backend) NewContextSize(n int) ml.Context {
Michael Yang's avatar
Michael Yang committed
350
	return &Context{
351
		b: b,
352
		ctx: C.ggml_init(C.struct_ggml_init_params{
353
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
354
355
			no_alloc: true,
		}),
356
		backend:       C.ggml_backend_sched_get_backend(b.sched, 0),
357
358
359
360
		maxGraphNodes: n,
		input:         b.input,
		output:        b.output,
		layers:        b.layers,
Michael Yang's avatar
Michael Yang committed
361
362
363
	}
}

364
func (b *Backend) CacheConfig() ml.CacheConfig {
365
366
367
368
369
	if b.flashAttention {
		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
	} else {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}
370
371
}

Michael Yang's avatar
Michael Yang committed
372
type Context struct {
373
	b *Backend
Michael Yang's avatar
Michael Yang committed
374

375
	ctx   *C.struct_ggml_context
Michael Yang's avatar
Michael Yang committed
376
	graph *C.struct_ggml_cgraph
377
378

	// backend is the backend used for new tensors
379
	backend *C.struct_ggml_backend
380

381
382
383
384
385
386
387
388
389
	// input is the backend used for inputs
	input *C.struct_ggml_backend

	// output is the backend used for outputs
	output *C.struct_ggml_backend

	// output is the backend used for repeating layers
	layers map[int]*C.struct_ggml_backend

390
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
391
392
}

393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
func (c *Context) Input() ml.Context {
	if c.input != nil {
		return &Context{
			b:             c.b,
			ctx:           c.ctx,
			backend:       c.input,
			maxGraphNodes: c.maxGraphNodes,
		}
	}

	return c
}

func (c *Context) Output() ml.Context {
	if c.output != nil {
		return &Context{
			b:             c.b,
			ctx:           c.ctx,
			backend:       c.output,
			maxGraphNodes: c.maxGraphNodes,
		}
	}

	return c
}

func (c *Context) Layer(i int) ml.Context {
	if backend, ok := c.layers[i]; ok {
		return &Context{
			b:             c.b,
			ctx:           c.ctx,
			backend:       backend,
			maxGraphNodes: c.maxGraphNodes,
		}
	}

	return c
}

432
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
Michael Yang's avatar
Michael Yang committed
433
	if c.graph == nil {
434
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
Michael Yang's avatar
Michael Yang committed
435
436
	}

437
438
439
440
441
	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
Michael Yang's avatar
Michael Yang committed
442
443
}

444
func (c *Context) Compute(tensors ...ml.Tensor) {
445
	C.ggml_backend_sched_reset(c.b.sched)
446
447
	C.ggml_backend_sched_alloc_graph(c.b.sched, c.graph)
	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
Michael Yang's avatar
Michael Yang committed
448

449
450
451
	needSync := true
	sync := func() {
		if needSync {
452
			C.ggml_backend_sched_synchronize(c.b.sched)
453
454
455
			needSync = false
		}
	}
Michael Yang's avatar
Michael Yang committed
456

457
458
459
	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
460
461
		}
	}
Michael Yang's avatar
Michael Yang committed
462
463
}

464
465
func (c *Context) MaxGraphNodes() int {
	return c.maxGraphNodes
Jesse Gross's avatar
Jesse Gross committed
466
467
}

468
469
470
func shapeToGGML(shape []int) *C.int64_t {
	sh := make([]C.int64_t, len(shape))
	for i, s := range shape {
471
		sh[i] = C.int64_t(s)
472
473
474
475
476
	}

	return &sh[0]
}

477
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
478
479
480
481
482
483
484
485
486
487
488
489
490
	if len(shape) < 1 || len(shape) > 4 {
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

	var t *C.struct_ggml_tensor
	switch dtype {
	case ml.DTypeF32:
491
		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
Jesse Gross's avatar
Jesse Gross committed
492
	case ml.DTypeF16:
493
		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
Michael Yang's avatar
Michael Yang committed
494
	case ml.DTypeI32:
495
		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
Michael Yang's avatar
Michael Yang committed
496
497
498
499
	default:
		panic("unsupported dtype")
	}

500
	b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
Michael Yang's avatar
Michael Yang committed
501
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
502
	return &Tensor{b: c.b, t: t}
503
504
505
}

func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
506
	return c.newTensor(dtype, shape)
507
508
509
}

func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
510
	t := c.newTensor(dtype, shape)
511
512
	C.ggml_set_zero(t.(*Tensor).t)
	return t
Michael Yang's avatar
Michael Yang committed
513
514
}

515
func checkShape[S ~[]E, E any](s S, shape ...int) error {
Michael Yang's avatar
Michael Yang committed
516
517
518
519
520
521
	n := len(s)
	for _, v := range shape {
		n /= v
	}

	if n != 1 {
522
		return fmt.Errorf("invalid shape: %v", shape)
Michael Yang's avatar
Michael Yang committed
523
524
	}

525
	return nil
Michael Yang's avatar
Michael Yang committed
526
527
528
}

func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
529
530
531
532
533
534
535
	if err := checkShape(s, shape...); err != nil {
		return nil, err
	}

	t := c.newTensor(ml.DTypeF32, shape)
	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	return t, nil
Michael Yang's avatar
Michael Yang committed
536
537
538
}

func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
539
540
541
542
543
544
545
	if err := checkShape(s, shape...); err != nil {
		return nil, err
	}

	t := c.newTensor(ml.DTypeI32, shape)
	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	return t, nil
Michael Yang's avatar
Michael Yang committed
546
547
}

548
549
func (c Context) Close() {
	if c.ctx != nil {
550
551
		C.ggml_free(c.ctx)
	}
Michael Yang's avatar
Michael Yang committed
552
553
554
}

type Tensor struct {
555
	b    *Backend
Michael Yang's avatar
Michael Yang committed
556
	t    *C.struct_ggml_tensor
557
	sync func()
Michael Yang's avatar
Michael Yang committed
558
559
560
561
562
563
564
565
566
567
}

func (t *Tensor) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", C.GoString(C.ggml_get_name(t.t))),
		slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
		slog.Any("shape", t.Shape()),
	)
}

568
569
func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
Michael Yang's avatar
Michael Yang committed
570
571
}

572
573
func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
Michael Yang's avatar
Michael Yang committed
574
575
}

576
577
func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
Michael Yang's avatar
Michael Yang committed
578
579
580
581
582
583
584
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

585
586
587
588
589
590
591
592
593
func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
Michael Yang's avatar
Michael Yang committed
594
595
}

596
597
598
599
600
601
func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
Michael Yang's avatar
Michael Yang committed
602
603
604
605
606
607
608
609
610
	}

	return
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
Jesse Gross's avatar
Jesse Gross committed
611
612
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
Michael Yang's avatar
Michael Yang committed
613
614
615
616
617
618
619
620
621
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
		return ml.DTypeOther
	}
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
622
		b: t.b,
Michael Yang's avatar
Michael Yang committed
623
624
625
626
627
628
629
630
631
632
633
634
635
636
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) > 0 {
		return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
	}

	return t
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
637
		b: t.b,
Michael Yang's avatar
Michael Yang committed
638
639
640
641
642
643
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
	return &Tensor{
644
		b: t.b,
Michael Yang's avatar
Michael Yang committed
645
646
647
648
649
650
		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
651
		b: t.b,
Michael Yang's avatar
Michael Yang committed
652
653
654
655
656
657
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
658
		b: t.b,
Michael Yang's avatar
Michael Yang committed
659
660
661
662
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

663
664
665
666
667
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
668
		b: t.b,
669
670
671
672
		t: mul,
	}
}

Michael Yang's avatar
Michael Yang committed
673
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
674
	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
Michael Yang's avatar
Michael Yang committed
675
676
677
678
679
680
681
682
	if b != nil {
		tt = tt.Add(ctx, b)
	}

	return tt
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
683
	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
Michael Yang's avatar
Michael Yang committed
684
685
}

686
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
687
688
689
690
691
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
692
		b: t.b,
Michael Yang's avatar
Michael Yang committed
693
694
695
696
697
698
699
700
701
702
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
703
		b: t.b,
Michael Yang's avatar
Michael Yang committed
704
705
706
707
708
709
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
710
		b: t.b,
Michael Yang's avatar
Michael Yang committed
711
712
713
714
715
716
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
717
		b: t.b,
Michael Yang's avatar
Michael Yang committed
718
719
720
721
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

722
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
723
724
725
	switch len(shape) {
	case 1:
		return &Tensor{
726
			b: t.b,
Michael Yang's avatar
Michael Yang committed
727
728
729
730
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
731
			b: t.b,
Michael Yang's avatar
Michael Yang committed
732
733
734
735
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
736
			b: t.b,
Michael Yang's avatar
Michael Yang committed
737
738
739
740
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
741
			b: t.b,
Michael Yang's avatar
Michael Yang committed
742
743
744
745
746
747
748
749
750
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
751
		b: t.b,
Michael Yang's avatar
Michael Yang committed
752
753
754
755
756
757
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
758
		b: t.b,
Michael Yang's avatar
Michael Yang committed
759
760
761
762
763
764
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
765
		b: t.b,
Michael Yang's avatar
Michael Yang committed
766
767
768
769
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

770
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
771
772
773
774
775
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
776
		b: t.b,
Michael Yang's avatar
Michael Yang committed
777
778
779
780
781
782
783
784
		t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
785
			b: t.b,
Michael Yang's avatar
Michael Yang committed
786
787
788
789
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
790
			b: t.b,
Michael Yang's avatar
Michael Yang committed
791
792
793
794
795
796
797
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
798
			b: t.b,
Michael Yang's avatar
Michael Yang committed
799
800
801
802
803
804
805
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
806
			b: t.b,
Michael Yang's avatar
Michael Yang committed
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

const (
	ropeTypeNorm C.int = iota
)

func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
	if ropeFactors == nil {
823
		ropeFactors = &Tensor{b: t.b}
Michael Yang's avatar
Michael Yang committed
824
825
	}

Jesse Gross's avatar
Jesse Gross committed
826
827
828
829
830
	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

Michael Yang's avatar
Michael Yang committed
831
	return &Tensor{
832
		b: t.b,
Michael Yang's avatar
Michael Yang committed
833
		t: C.ggml_rope_ext(
Jesse Gross's avatar
Jesse Gross committed
834
			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
Michael Yang's avatar
Michael Yang committed
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
			C.int(ropeDim),
			131072,       // YaRN n_ctx_train
			ropeTypeNorm, // ROPE_TYPE_NORM
			C.float(ropeBase),
			C.float(ropeScale),
			0.,  // YaRN ext_factor
			1.,  // YaRN attn_factor
			32., // YaRN beta_fast
			1.,  // YaRN beta_slow
		),
	}
}

func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
	return &Tensor{
850
		b: t.b,
Michael Yang's avatar
Michael Yang committed
851
852
853
854
855
856
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
	return &Tensor{
857
		b: t.b,
Michael Yang's avatar
Michael Yang committed
858
859
860
861
862
863
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
864
		b: t.b,
Michael Yang's avatar
Michael Yang committed
865
866
867
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}
868

869
870
871
872
873
874
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

875
876
877
	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

878
879
	if t.b.flashAttention {
		value = value.Permute(ctx, 0, 2, 1, 3)
880

881
882
883
884
885
886
887
888
889
890
891
892
893
		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
		return &Tensor{b: t.b, t: kqv}
	} else {
		kq := key.MulmatFullPrec(ctx, query)
		kq = &Tensor{
			b: t.b,
			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
		}

		kqv := value.Mulmat(ctx, kq)
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
894
}