package ggml

// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
import "C"

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"maps"
	"os"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync/atomic"
	"unicode"
	"unsafe"

	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs"
	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
	"github.com/ollama/ollama/ml/nn/rope"
	"golang.org/x/sync/errgroup"
)

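// devices returns the ggml backend devices that are available at runtime,
// after making sure the ggml backends have been loaded (ggml.OnceLoad).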
func devices() []*C.struct_ggml_backend_device {
	ggml.OnceLoad()
	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
	for i := range ds {
		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
	}

	return ds
}

type Backend struct {
	// modelPath is the location of the model data
	modelPath string

	meta *fsggml.GGML

	// tensorLoadTargets maps from the name of the tensor in the file
	// to the name that is used by the model definition
	tensorLoadTargets map[string][]string

	sched         *C.struct_ggml_backend_sched
	schedBackends []*C.struct_ggml_backend
	schedBufts    []*C.struct_ggml_backend_buffer_type

	tensors map[string]*C.struct_ggml_tensor

	// input is the backend used for inputs
	input *C.struct_ggml_backend_buffer_type

	// layers is the backend used for repeating layers
	layers map[int]*C.struct_ggml_backend_buffer_type

	// requiredMemory is the cumulative memory allocations needed by the backend
	requiredMemory *ml.BackendMemory

	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
	btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory

	flashAttention bool

	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
	maxGraphNodes int
}

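// New decodes the GGUF metadata at modelPath, assigns each tensor to a CPU or
// GPU buffer type based on params, allocates the backing buffers, and creates
// the scheduler used to run compute graphs. Tensor data is read later by Load.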
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
	r, err := os.Open(modelPath)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	meta, err := fsggml.Decode(r, -1)
	if err != nil {
		return nil, err
	}

	slog.Info(
		"",
		"architecture", meta.KV().Architecture(),
		"file_type", meta.KV().FileType(),
		"name", meta.KV().String("general.name"),
		"description", meta.KV().String("general.description"),
		"num_tensors", len(meta.Tensors().Items()),
		"num_key_values", len(meta.KV()),
	)

	var requiredMemory ml.BackendMemory
	btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)

	type deviceBufferType struct {
		d   *C.struct_ggml_backend_device
		bts []*C.struct_ggml_backend_buffer_type
	}

	var cpus, accels, gpus []*C.struct_ggml_backend_device
	for _, d := range devices() {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
			if len(cpus) == 0 {
				// only the first cpu device should be used
				cpus = append(cpus, d)
			}
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
			gpus = append(gpus, d)
		}
	}

	blocks := int(meta.KV().BlockCount())

	// create list of buffer types for the cpu
	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
			btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
		}
	}

	requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
	var props C.struct_ggml_backend_dev_props
	C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
	requiredMemory.CPU.UUID = C.GoString(props.uuid)
	requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
	requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)

	// create list of buffer types for each gpu
	var gpuDeviceBufferTypes []deviceBufferType
	requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
	for i, d := range gpus {
		bt := C.ggml_backend_dev_buffer_type(d)
		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
			d:   d,
			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
		})
		btDeviceMemory[bt] = &requiredMemory.GPUs[i]
		requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
		var props C.struct_ggml_backend_dev_props
		C.ggml_backend_dev_get_props(d, &props)
		requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
		requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
		requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
	}

	useDefaultSplit := true
	for _, s := range params.TensorSplit {
		if s != 0 {
			useDefaultSplit = false
			break
		}
	}

	// calculate splits
	splits := make([]float32, len(gpus))
	if useDefaultSplit {
		// default: split on free memory
		for i := range splits {
			var free, total C.size_t
			C.ggml_backend_dev_memory(gpus[i], &free, &total)
			splits[i] = float32(free)
		}
	} else {
		splits = params.TensorSplit
	}

	var sum float32
	// cumulative sum of all splits
	for i := range splits {
		sum += splits[i]
		splits[i] = sum
	}

	// normalize splits
	for i := range splits {
		splits[i] /= sum
	}

	// inputs always use cpu
	input := cpuDeviceBufferType

	// define a range of gpu layers. anything outside of this range is assigned to the cpu
	gpuRangeStart := max(0, blocks-params.NumGPULayers)
	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
	assignLayer := func(i int) deviceBufferType {
		if i < gpuRangeStart || i >= gpuRangeStop {
			return cpuDeviceBufferType
		}

		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
		if index < 0 || index >= len(gpuDeviceBufferTypes) {
			return cpuDeviceBufferType
		}

		return gpuDeviceBufferTypes[index]
	}

	// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
	layers := make([]deviceBufferType, blocks)
	for i := range layers {
		layers[i] = assignLayer(i)
	}

	// outputs are assigned iff allowed by splits and configured number of gpu layers
	output := assignLayer(blocks)

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
	// each layer has at most 2 extra tensors for rope operations
	maxTensors += blocks * 2

	type tensor struct {
		source *fsggml.Tensor
		target string
	}

	// some tensors are mapped to different names so keep a list
	targets := make(map[string][]string)

	// contexts are shared by tensors of the same buffer type
	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}

			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
			defer C.free(unsafe.Pointer(cname))
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
			C.ggml_set_name(tt, cname)

			slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))

			size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
			if layer == -1 {
				// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
				requiredMemory.InputWeights.Status = ml.Allocated
				requiredMemory.InputWeights.Size += uint64(size)
			} else {
				btDeviceMemory[bt].Weights[layer].Size += uint64(size)
			}

			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
	}

	contains := func(s string, parts ...string) bool {
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
	}

	for _, t := range meta.Tensors().Items() {
		switch {
		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
			createTensor(tensor{source: t}, input.bts, -1)
			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
				createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
			}
		case contains(t.Name, "cls", "output", "output_norm"):
			createTensor(tensor{source: t}, output.bts, blocks)
		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
			// TODO: assign vision tensors to the gpu if possible
			createTensor(tensor{source: t}, output.bts, blocks)
		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
			// these tensors should be repeated per layer
			for i, layer := range layers {
				createTensor(tensor{
					source: t,
					target: "blk." + strconv.Itoa(i) + "." + t.Name,
				}, layer.bts, i)
			}
		default:
			layerIndex := -1
			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
				if i, err := strconv.Atoi(fields[0]); err == nil {
					layerIndex = i
				}
			}

			if layerIndex >= 0 {
				createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
			} else {
				// load all other tensors on the cpu
				createTensor(tensor{source: t}, input.bts, -1)
			}
		}
	}

	// allocate buffers for each context
	bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
		for i := range btDeviceMemory[bt].Weights {
			if btDeviceMemory[bt].Weights[i].Size != 0 {
				if b != nil {
					btDeviceMemory[bt].Weights[i].Status = ml.Allocated
				} else {
					btDeviceMemory[bt].Weights[i].Status = ml.Failed
				}
			}
		}

		if b == nil {
			panic(ml.ErrNoMem{BackendMemory: requiredMemory})
		}

		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
		bbs[c] = b
	}

	for bs := range maps.Values(bbs) {
		slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
	}

	// map tensor names to tensors for easy lookup later
	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

	// map devices to backend buffer types so new tensors can be assigned to the correct device
	deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)

	// create backends and buffer types used for the compute graph scheduler
	var schedBackends []*C.struct_ggml_backend
	var schedBufts []*C.struct_ggml_backend_buffer_type
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := C.ggml_backend_dev_init(d, nil)
		bt := C.ggml_backend_get_default_buffer_type(b)

		deviceBufferTypes[d] = bt

		schedBackends = append(schedBackends, b)
		schedBufts = append(schedBufts, bt)

		if C.ggml_backend_is_cpu(b) {
			// set number of threads for cpu backend
			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
		}
	}

	maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
	return &Backend{
		modelPath:         modelPath,
		flashAttention:    params.FlashAttention,
		meta:              meta,
		tensorLoadTargets: targets,
		tensors:           tensors,
		sched: C.ggml_backend_sched_new(
			(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
			C.int(len(schedBackends)),
			C.size_t(maxGraphNodes),
			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
			C._Bool(false),
		),
		schedBackends: schedBackends,
		schedBufts:    schedBufts,
		input:         deviceBufferTypes[input.d],
		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
			m := make(map[int]*C.struct_ggml_backend_buffer_type)
			for i, layer := range layers {
				m[i] = deviceBufferTypes[layer.d]
			}
			return m
		}(),
		requiredMemory: &requiredMemory,
		btDeviceMemory: btDeviceMemory,
		maxGraphNodes:  maxGraphNodes,
	}, nil
}

func init() {
	ml.RegisterBackend("ggml", New)
}

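// Load reads tensor data from the model file into the allocated backend
// buffers, one goroutine per tensor, reporting progress as a fraction of the
// total bytes when a progress callback is provided.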
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
	var doneBytes atomic.Uint64
	totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset

	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))
	for _, t := range b.meta.Tensors().Items() {
		t := t
		g.Go(func() error {
			tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
			for i := range tts {
				target := b.tensorLoadTargets[t.Name][i]
				if target == "" {
					target = t.Name
				}

				tt, ok := b.tensors[target]
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}

				tts[i] = tt
			}

			// Create a new FD for each goroutine so that each FD is read sequentially, rather than
			// seeking around within an FD shared between all goroutines.
			file, err := os.Open(b.modelPath)
			if err != nil {
				slog.Warn("file open error", "file", b.modelPath, "error", err)
				return err
			}
			defer file.Close()
			sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
			bts := make([]byte, 128*format.KibiByte)

			var s uint64
			for s < t.Size() {
				// Stop if either the parent context has been canceled or if any of the other tensors returned an error
				if err := ctx.Err(); err != nil {
					return err
				}

				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
				if err != nil {
					slog.Warn("file read error", "file", b.modelPath, "error", err)
					return err
				}

				for _, tt := range tts {
					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
				}

				s += uint64(n)

				if progress != nil {
					done := doneBytes.Add(uint64(n))
					progress(float32(done) / float32(totalBytes))
				}
			}

			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}

	return nil
}

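// BackendMemory returns the memory requirements recorded so far for weights,
// caches, and compute graphs.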
func (b *Backend) BackendMemory() ml.BackendMemory {
	return *b.requiredMemory
}

func (b *Backend) Config() fs.Config {
	return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
	return b.NewContextSize(b.maxGraphNodes)
}

func (b *Backend) NewContextSize(n int) ml.Context {
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}

	var allocatedBuffers []*C.struct_ggml_backend_buffer

	return &Context{
		b:             b,
		maxGraphNodes: n,
		ctx: C.ggml_init(C.struct_ggml_init_params{
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
			no_alloc: true,
		}),
		allocatedBuffers: &allocatedBuffers,
		layer:            -1,
	}
}

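// CacheConfig reports the cache layout this backend expects, which depends on
// whether flash attention is enabled.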
func (b *Backend) CacheConfig() ml.CacheConfig {
	if b.flashAttention {
		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
	} else {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}
}

type Context struct {
	b *Backend

	ctx   *C.struct_ggml_context
	graph *C.struct_ggml_cgraph

	// buft is the buffer type used for new tensors
	buft *C.struct_ggml_backend_buffer_type

	// allocatedBuffers are buffers for tensors that we have allocated in this context
	// so that we can free them when we close the context
	allocatedBuffers *[]*C.struct_ggml_backend_buffer

	// maxGraphNodes is the maximum allowed number of graph nodes in this context
	maxGraphNodes int

	// layer is the graph layer that this context is allocating for - assumed to be cache
	layer int
}

func (c *Context) Input() ml.Context {
	if c.b.input != nil {
		return &Context{
			b:                c.b,
			ctx:              c.ctx,
			buft:             c.b.input,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
			layer:            -1,
		}
	}

	return c
}

func (c *Context) Layer(i int) ml.Context {
	if buft, ok := c.b.layers[i]; ok {
		return &Context{
			b:                c.b,
			ctx:              c.ctx,
			buft:             buft,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
			layer:            i,
		}
	}

	return c
}

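// Forward adds tensors to the compute graph, creating the graph on first use.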
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
	if c.graph == nil {
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
	}

	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
}

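// Compute runs the current graph asynchronously; output tensors are given a
// sync function so reading their data blocks until computation has finished.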
func (c *Context) Compute(tensors ...ml.Tensor) {
	if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
		panic(fmt.Errorf("error computing ggml graph: %v", status))
	}
	C.ggml_backend_sched_reset(c.b.sched)

	needSync := true
	sync := func() {
		if needSync {
			C.ggml_backend_sched_synchronize(c.b.sched)
			needSync = false
		}
	}

	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
		}
	}
}

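// Reserve sizes the scheduler buffers for the current graph without running
// it, recording per-backend graph memory and panicking with ml.ErrNoMem if
// the reservation fails.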
func (c *Context) Reserve() {
	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)

	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))

	// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
	for _, bt := range c.b.schedBufts {
		c.b.btDeviceMemory[bt].Graph = ml.Memory{}
	}

	for i := range c.b.schedBackends {
		bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])

		graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
		graph.Size += uint64(bufferStatus.size)
		if bufferStatus.allocated && graph.Status != ml.Failed {
			graph.Status = ml.Allocated
		} else {
			graph.Status = ml.Failed
		}

		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
			"size", format.HumanBytes2(uint64(bufferStatus.size)))
	}

	if !reserved {
		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
	}
}

func (c *Context) MaxGraphNodes() int {
	return c.maxGraphNodes
}

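// shapeToGGML converts a Go shape slice into the int64 array expected by the
// ggml tensor constructors.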
func shapeToGGML(shape []int) *C.int64_t {
	sh := make([]C.int64_t, len(shape))
	for i, s := range shape {
		sh[i] = C.int64_t(s)
	}

	return &sh[0]
}

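// pad rounds length up to the next multiple of pad.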
func pad(length, pad C.size_t) C.size_t {
	return ((length + pad - 1) / pad) * pad
}

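// newTensor allocates a tensor of the given dtype and shape from the context's
// current buffer type, recording the allocation against the layer's cache
// memory when the context is bound to a layer.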
func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
	if c.buft == nil {
		panic("set Input or Layer before creating tensors")
	}

	var cdtype uint32
	switch dtype {
	case ml.DTypeF32:
		cdtype = C.GGML_TYPE_F32
	case ml.DTypeF16:
		cdtype = C.GGML_TYPE_F16
	case ml.DTypeQ80:
		cdtype = C.GGML_TYPE_Q8_0
	case ml.DTypeQ40:
		cdtype = C.GGML_TYPE_Q4_0
	case ml.DTypeI32:
		cdtype = C.GGML_TYPE_I32
	default:
		panic("unsupported dtype")
	}

	if len(shape) < 1 || shape[0] == 0 {
		var shape C.int64_t = 0
		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
	} else if len(shape) > 4 {
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))

	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
	if c.layer >= 0 {
		cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]

		cache.Size += uint64(size)
		if b != nil {
			cache.Status = ml.Allocated
		} else {
			cache.Status = ml.Failed
		}
	}

	if b == nil {
		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
	}

	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
	return &Tensor{b: c.b, t: t}
}

func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	return c.newTensor(dtype, shape)
}

func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	t := c.newTensor(dtype, shape)
	C.ggml_set_zero(t.(*Tensor).t)
	return t
}

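// checkShape panics if the number of elements in s is not consistent with the
// given shape; empty slices are accepted as-is.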
func checkShape[S ~[]E, E any](s S, shape ...int) {
	n := len(s)

	if n == 0 {
		return
	}

	for _, v := range shape {
		n /= v
	}

	if n != 1 {
		panic(fmt.Errorf("invalid shape: %v", shape))
	}
}

func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
	checkShape(s, shape...)

	t := c.newTensor(ml.DTypeF32, shape)

	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

	return t
}

func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
	checkShape(s, shape...)

	t := c.newTensor(ml.DTypeI32, shape)

	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

	return t
}

func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	switch dtype {
	case ml.DTypeF32:
		// ggml_arange creates a float32 tensor
		return &Tensor{
			b: c.b,
			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
		}
	case ml.DTypeI32:
		// ggml_cast does not support float32 to int32 conversion
		arange := make([]int32, 0, int((stop-start)/step))
		for i := start; i < stop; i += step {
			arange = append(arange, int32(i))
		}

		return c.Input().FromIntSlice(arange, len(arange))
	default:
		panic("unsupported dtype for arange")
	}
}

func (c *Context) Close() {
	if c != nil {
		for _, b := range *c.allocatedBuffers {
			C.ggml_backend_buffer_free(b)
		}
		*c.allocatedBuffers = nil

		C.ggml_free(c.ctx)
	}
}

type Tensor struct {
	b    *Backend
	t    *C.struct_ggml_tensor
	sync func()
}

func (t *Tensor) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", C.GoString(C.ggml_get_name(t.t))),
		slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
		slog.Any("shape", t.Shape()),
	)
}

func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
}

func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
}

func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
}

func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
	case C.GGML_TYPE_Q8_0:
		return ml.DTypeQ80
	case C.GGML_TYPE_Q4_0:
		return ml.DTypeQ40
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
		return ml.DTypeOther
	}
}

func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	}

	shape := make([]C.int64_t, C.GGML_MAX_DIMS)
	for i := range C.GGML_MAX_DIMS {
		if i == dim {
			shape[i] = C.int64_t(t.Dim(i) * n)
		} else {
			shape[i] = C.int64_t(t.Dim(i))
		}
	}

	tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
	return &Tensor{
		b: t.b,
		t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
	}
}

func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) > 0 {
		return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
	}

	return t
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
		b: t.b,
		t: mul,
	}
}

func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
	}
}

func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
		if b != nil {
			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
		}
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	} else if shape[3] != 0 {
		panic("cuda does not support 4d tensors")
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sin(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cos(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
	// Default options
	opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}

	// Apply any provided options
	for _, option := range options {
		option(opts)
	}

	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_rope_ext(
			ctx.(*Context).ctx,
			dequant,
			positions.(*Tensor).t,
			opts.Factors.(*Tensor).t,
			C.int(ropeDim),
			C.int(opts.Type),
			C.int(opts.OriginalContextLength),
			C.float(ropeBase),
			C.float(ropeScale),
			C.float(0.0),
			C.float(1.0),
			C.float(32.0),
			C.float(1.0),
		),
	}
}

func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
	}
}

func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}

func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
	}
}

func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
	case 1:
		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

	if t.b.flashAttention {
		value = value.Permute(ctx, 0, 2, 1, 3)

		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
		return &Tensor{b: t.b, t: kqv}
	} else {
		kq := key.MulmatFullPrec(ctx, query)
		kq = &Tensor{
			b: t.b,
			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
		}

		kqv := value.Mulmat(ctx, kq)
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}

func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
	}
}

func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
	}
}