ggml.go 29.4 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
package ggml

3
4
5
6
7
8
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
Michael Yang's avatar
Michael Yang committed
9
10
11
import "C"

import (
12
	"context"
13
	"errors"
Michael Yang's avatar
Michael Yang committed
14
15
16
	"fmt"
	"io"
	"log/slog"
17
	"maps"
Michael Yang's avatar
Michael Yang committed
18
	"os"
19
	"runtime"
20
21
22
	"slices"
	"strconv"
	"strings"
23
	"sync/atomic"
24
	"unicode"
Michael Yang's avatar
Michael Yang committed
25
26
27
	"unsafe"

	"github.com/ollama/ollama/format"
28
29
	"github.com/ollama/ollama/fs"
	fsggml "github.com/ollama/ollama/fs/ggml"
Michael Yang's avatar
Michael Yang committed
30
	"github.com/ollama/ollama/ml"
31
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
Michael Yang's avatar
Michael Yang committed
32
33
34
	"golang.org/x/sync/errgroup"
)

Michael Yang's avatar
Michael Yang committed
35
36
37
38
39
func devices() []*C.struct_ggml_backend_device {
	ggml.OnceLoad()
	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
	for i := range ds {
		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
Michael Yang's avatar
Michael Yang committed
40
	}
Michael Yang's avatar
Michael Yang committed
41
42

	return ds
43
}
Michael Yang's avatar
Michael Yang committed
44
45

type Backend struct {
46
47
48
49
50
51
	meta *fsggml.GGML

	sched         *C.struct_ggml_backend_sched
	schedBackends []*C.struct_ggml_backend
	schedBufts    []*C.struct_ggml_backend_buffer_type

52
	tensors map[string]*C.struct_ggml_tensor
Michael Yang's avatar
Michael Yang committed
53
54

	// input is the backend used for inputs
55
	input *C.struct_ggml_backend_buffer_type
Michael Yang's avatar
Michael Yang committed
56
57

	// layers is the backend used for repeating layers
58
	layers map[int]*C.struct_ggml_backend_buffer_type
59

60
	flashAttention bool
Michael Yang's avatar
Michael Yang committed
61
62
63

	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
64
65
}

66
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
67
	meta, n, err := fsggml.Decode(r, -1)
Michael Yang's avatar
Michael Yang committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
	if err != nil {
		return nil, err
	}

	slog.Info(
		"",
		"architecture", meta.KV().Architecture(),
		"file_type", meta.KV().FileType(),
		"name", meta.KV().String("general.name"),
		"description", meta.KV().String("general.description"),
		"num_tensors", len(meta.Tensors().Items()),
		"num_key_values", len(meta.KV()),
	)

82
	type deviceBufferType struct {
83
84
85
86
87
		d   *C.struct_ggml_backend_device
		bts []*C.struct_ggml_backend_buffer_type
	}

	var cpus, accels, gpus []*C.struct_ggml_backend_device
Michael Yang's avatar
Michael Yang committed
88
	for _, d := range devices() {
89
90
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
91
92
93
94
			if len(cpus) == 0 {
				// only the first cpu device should be used
				cpus = append(cpus, d)
			}
95
96
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
Michael Yang's avatar
Michael Yang committed
97
		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
98
			gpus = append(gpus, d)
Michael Yang's avatar
Michael Yang committed
99
100
101
		}
	}

Michael Yang's avatar
Michael Yang committed
102
	// create list of buffer types for the cpu
Michael Yang's avatar
Michael Yang committed
103
	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
104
105
106
107
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
Michael Yang's avatar
Michael Yang committed
108
			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
Michael Yang's avatar
Michael Yang committed
109
		}
110
111
	}

Michael Yang's avatar
Michael Yang committed
112
	// create list of buffer types for each gpu
113
	var gpuDeviceBufferTypes []deviceBufferType
114
115
	for _, d := range gpus {
		bt := C.ggml_backend_dev_buffer_type(d)
116
		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
117
			d:   d,
Michael Yang's avatar
Michael Yang committed
118
			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
119
		})
Michael Yang's avatar
Michael Yang committed
120
121
	}

Michael Yang's avatar
Michael Yang committed
122
123
124
125
126
	useDefaultSplit := true
	for _, s := range params.TensorSplit {
		if s != 0 {
			useDefaultSplit = false
			break
127
		}
Michael Yang's avatar
Michael Yang committed
128
	}
129

Michael Yang's avatar
Michael Yang committed
130
131
132
133
	// calculate splits
	splits := make([]float32, len(gpus))
	if useDefaultSplit {
		// default: split on free memory
134
135
136
137
138
		for i := range splits {
			var free, total C.size_t
			C.ggml_backend_dev_memory(gpus[i], &free, &total)
			splits[i] = float32(free)
		}
Michael Yang's avatar
Michael Yang committed
139
140
	} else {
		splits = params.TensorSplit
141
142
143
	}

	var sum float32
Michael Yang's avatar
Michael Yang committed
144
	// cumulative sum of all splits
145
146
147
148
149
	for i := range splits {
		sum += splits[i]
		splits[i] = sum
	}

Michael Yang's avatar
Michael Yang committed
150
	// normalize splits
151
	for i := range splits {
152
		splits[i] /= sum
153
154
	}

Michael Yang's avatar
Michael Yang committed
155
	// inputs always use cpu
Michael Yang's avatar
Michael Yang committed
156
	input := cpuDeviceBufferType
157

158
	blocks := int(meta.KV().BlockCount())
Michael Yang's avatar
Michael Yang committed
159
160
161
162

	// define a range of gpu layers. anything outside of this range is assigned to the cpu
	gpuRangeStart := max(0, blocks-params.NumGPULayers)
	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
Michael Yang's avatar
Michael Yang committed
163
	assignLayer := func(i int) deviceBufferType {
Michael Yang's avatar
Michael Yang committed
164
		if i < gpuRangeStart || i >= gpuRangeStop {
Michael Yang's avatar
Michael Yang committed
165
			return cpuDeviceBufferType
166
		}
167

Michael Yang's avatar
Michael Yang committed
168
		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
169
		if index < 0 || index >= len(gpuDeviceBufferTypes) {
Michael Yang's avatar
Michael Yang committed
170
			return cpuDeviceBufferType
171
172
173
		}

		return gpuDeviceBufferTypes[index]
174
175
	}

Michael Yang's avatar
Michael Yang committed
176
	// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
177
	layers := make([]deviceBufferType, blocks)
178
	for i := range layers {
179
		layers[i] = assignLayer(i)
180
181
	}

Michael Yang's avatar
Michael Yang committed
182
	// outputs are assigned iff allowed by splits and configured number of gpu layers
183
	output := assignLayer(blocks)
184
185
186

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
Michael Yang's avatar
Michael Yang committed
187
	// each layer has at most 2 extra tensors for rope operations
188
189
	maxTensors += blocks * 2

190
	type tensor struct {
191
		source *fsggml.Tensor
192
193
194
		target string
	}

Michael Yang's avatar
Michael Yang committed
195
	// some tensors are mapped to different names so keep a list
196
197
	targets := make(map[string][]string)

Michael Yang's avatar
Michael Yang committed
198
	// contexts are shared by tensors of the same buffer type
199
	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
200
	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
201
202
203
204
205
206
207
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}
Michael Yang's avatar
Michael Yang committed
208

209
210
211
212
213
214
215
216
			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
Michael Yang's avatar
Michael Yang committed
217
			defer C.free(unsafe.Pointer(cname))
218
219
220
221
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

222
			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
Michael Yang's avatar
Michael Yang committed
223
224
			C.ggml_set_name(tt, cname)

225
			slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
226
227
228
229
230
			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
Michael Yang's avatar
Michael Yang committed
231
232
	}

233
	contains := func(s string, parts ...string) bool {
234
235
236
237
238
239
240
241
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
Michael Yang's avatar
Michael Yang committed
242
243
	}

244
245
	for _, t := range meta.Tensors().Items() {
		switch {
246
		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
247
			createTensor(tensor{source: t}, input.bts)
Michael Yang's avatar
Michael Yang committed
248
249
250
			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
			}
251
		case contains(t.Name, "cls", "output", "output_norm"):
252
			createTensor(tensor{source: t}, output.bts)
253
		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
Michael Yang's avatar
Michael Yang committed
254
			// TODO: assign vision tensors to the gpu if possible
Michael Yang's avatar
Michael Yang committed
255
			createTensor(tensor{source: t}, output.bts)
Michael Yang's avatar
Michael Yang committed
256
257
258
259
260
261
262
263
		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
			// these tensors should be repeated per layer
			for i, layer := range layers {
				createTensor(tensor{
					source: t,
					target: "blk." + strconv.Itoa(i) + "." + t.Name,
				}, layer.bts)
			}
264
		default:
Michael Yang's avatar
Michael Yang committed
265
266
267
268
			layerIndex := -1
			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
				if i, err := strconv.Atoi(fields[0]); err == nil {
					layerIndex = i
269
				}
Michael Yang's avatar
Michael Yang committed
270
			}
271

Michael Yang's avatar
Michael Yang committed
272
273
			if layerIndex >= 0 {
				createTensor(tensor{source: t}, layers[layerIndex].bts)
274
			} else {
Michael Yang's avatar
Michael Yang committed
275
276
				// load all other tensors on the cpu
				createTensor(tensor{source: t}, input.bts)
277
278
279
			}
		}
	}
Michael Yang's avatar
Michael Yang committed
280

Michael Yang's avatar
Michael Yang committed
281
282
	// allocate buffers for each context
	bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
283
284
285
286
287
288
	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
289
290
291
292
		if b == nil {
			return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
		}

293
		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
Michael Yang's avatar
Michael Yang committed
294
		bbs[c] = b
295
296
297
	}

	for bs := range maps.Values(bbs) {
Michael Yang's avatar
Michael Yang committed
298
		slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
299
300
	}

Michael Yang's avatar
Michael Yang committed
301
	// map tensor names to tensors for easy lookup later
302
303
304
305
306
307
308
	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

309
310
311
312
313
	var doneBytes atomic.Uint64
	totalBytes := uint64(n) - meta.Tensors().Offset

	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))
314
	for _, t := range meta.Tensors().Items() {
315
316
317
318
		g.Go(func() error {
			tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
			for i := range tts {
				target := targets[t.Name][i]
319
320
321
				if target == "" {
					target = t.Name
				}
322

323
324
325
326
				tt, ok := tensors[target]
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}
Michael Yang's avatar
Michael Yang committed
327

328
329
330
				tts[i] = tt
			}

331
332
333
334
			// Create a new FD for each goroutine so that each FD is read sequentially, rather than
			// seeking around within an FD shared between all goroutines.
			file, err := os.Open(r.Name())
			if err != nil {
Jesse Gross's avatar
Jesse Gross committed
335
				slog.Warn("file open error", "file", r.Name(), "error", err)
336
337
338
339
				return err
			}
			defer file.Close()
			sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
340
341
342
343
			bts := make([]byte, 128*format.KibiByte)

			var s uint64
			for s < t.Size() {
344
345
346
347
348
				// Stop if either the parent context has been canceled or if any of the other tensors returned an error
				if err := ctx.Err(); err != nil {
					return err
				}

349
350
				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
				if err != nil {
Jesse Gross's avatar
Jesse Gross committed
351
					slog.Warn("file read error", "file", r.Name(), "error", err)
352
					return err
353
				}
Michael Yang's avatar
Michael Yang committed
354

355
356
				for _, tt := range tts {
					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
357
				}
Michael Yang's avatar
Michael Yang committed
358

359
360
361
362
363
364
365
366
367
368
				s += uint64(n)

				if params.Progress != nil {
					done := doneBytes.Add(uint64(n))
					params.Progress(float32(done) / float32(totalBytes))
				}
			}

			return nil
		})
Michael Yang's avatar
Michael Yang committed
369
370
	}

371
	if err := g.Wait(); err != nil {
Michael Yang's avatar
Michael Yang committed
372
373
374
		return nil, err
	}

375
376
	// map devices to backend buffer types so new tensors can be assigned to the correct device
	deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
Michael Yang's avatar
Michael Yang committed
377
378
379
380

	// create backends and buffer types used for the compute graph scheduler
	var schedBackends []*C.struct_ggml_backend
	var schedBufts []*C.struct_ggml_backend_buffer_type
381
382
383
384
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := C.ggml_backend_dev_init(d, nil)
		bt := C.ggml_backend_get_default_buffer_type(b)

385
386
387
		deviceBufferTypes[d] = bt

		schedBackends = append(schedBackends, b)
Michael Yang's avatar
Michael Yang committed
388
		schedBufts = append(schedBufts, bt)
389

390
		if C.ggml_backend_is_cpu(b) {
Michael Yang's avatar
Michael Yang committed
391
			// set number of threads for cpu backend
Michael Yang's avatar
Michael Yang committed
392
			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
393
		}
394
395
	}

Michael Yang's avatar
Michael Yang committed
396
	maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
Michael Yang's avatar
Michael Yang committed
397
	return &Backend{
398
		flashAttention: params.FlashAttention,
399
400
		meta:           meta,
		tensors:        tensors,
401
		sched: C.ggml_backend_sched_new(
Michael Yang's avatar
Michael Yang committed
402
403
404
405
			(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
			C.int(len(schedBackends)),
			C.size_t(maxGraphNodes),
406
			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
407
		),
408
409
410
		schedBackends: schedBackends,
		schedBufts:    schedBufts,
		input:         deviceBufferTypes[input.d],
411
412
		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
			m := make(map[int]*C.struct_ggml_backend_buffer_type)
413
			for i, layer := range layers {
414
				m[i] = deviceBufferTypes[layer.d]
415
416
417
			}
			return m
		}(),
Michael Yang's avatar
Michael Yang committed
418
		maxGraphNodes: maxGraphNodes,
Michael Yang's avatar
Michael Yang committed
419
420
421
422
423
424
425
	}, nil
}

func init() {
	ml.RegisterBackend("ggml", New)
}

426
func (b *Backend) Config() fs.Config {
Michael Yang's avatar
Michael Yang committed
427
428
429
430
	return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
431
432
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
Michael Yang's avatar
Michael Yang committed
433
434
435
436
437
438
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
Michael Yang's avatar
Michael Yang committed
439
	return b.NewContextSize(b.maxGraphNodes)
440
441
442
}

func (b *Backend) NewContextSize(n int) ml.Context {
Jesse Gross's avatar
Jesse Gross committed
443
444
445
446
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}

447
448
	var allocatedBuffers []*C.struct_ggml_backend_buffer

Michael Yang's avatar
Michael Yang committed
449
	return &Context{
450
451
		b:             b,
		maxGraphNodes: n,
452
		ctx: C.ggml_init(C.struct_ggml_init_params{
453
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
454
455
			no_alloc: true,
		}),
456
		allocatedBuffers: &allocatedBuffers,
Michael Yang's avatar
Michael Yang committed
457
458
459
	}
}

460
func (b *Backend) CacheConfig() ml.CacheConfig {
461
462
463
464
465
	if b.flashAttention {
		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
	} else {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}
466
467
}

Michael Yang's avatar
Michael Yang committed
468
type Context struct {
469
	b *Backend
Michael Yang's avatar
Michael Yang committed
470

471
	ctx   *C.struct_ggml_context
Michael Yang's avatar
Michael Yang committed
472
	graph *C.struct_ggml_cgraph
473

474
475
	// buft is the buffer type used for new tensors
	buft *C.struct_ggml_backend_buffer_type
476

477
478
479
480
	// allocatedBuffers are buffers for tensors that we have allocated in this context
	// so that we can free them when we close the context
	allocatedBuffers *[]*C.struct_ggml_backend_buffer

Michael Yang's avatar
Michael Yang committed
481
	// maxGraphNodes is the maximum allowed number of graph nodes in this context
482
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
483
484
}

485
func (c *Context) Input() ml.Context {
Michael Yang's avatar
Michael Yang committed
486
	if c.b.input != nil {
487
		return &Context{
488
489
490
491
492
			b:                c.b,
			ctx:              c.ctx,
			buft:             c.b.input,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
493
494
495
		}
	}

496
	return c
497
498
}

499
func (c *Context) Layer(i int) ml.Context {
500
	if buft, ok := c.b.layers[i]; ok {
501
		return &Context{
502
503
504
505
506
			b:                c.b,
			ctx:              c.ctx,
			buft:             buft,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
507
508
509
		}
	}

510
	return c
511
512
}

513
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
Michael Yang's avatar
Michael Yang committed
514
	if c.graph == nil {
515
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
Michael Yang's avatar
Michael Yang committed
516
517
	}

518
519
520
521
522
	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
Michael Yang's avatar
Michael Yang committed
523
524
}

525
func (c *Context) Compute(tensors ...ml.Tensor) {
526
	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
Michael Yang's avatar
Michael Yang committed
527
	C.ggml_backend_sched_reset(c.b.sched)
Michael Yang's avatar
Michael Yang committed
528

529
530
531
	needSync := true
	sync := func() {
		if needSync {
532
			C.ggml_backend_sched_synchronize(c.b.sched)
533
534
535
			needSync = false
		}
	}
Michael Yang's avatar
Michael Yang committed
536

537
538
539
	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
540
541
		}
	}
Michael Yang's avatar
Michael Yang committed
542
543
}

544
func (c *Context) Reserve() error {
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
	if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
		C.ggml_backend_sched_reset(c.b.sched)
		return errors.New("failed to reserve graph")
	}

	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
	for i := range c.b.schedBackends {
		size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
			"size", format.HumanBytes2(uint64(size)))
	}

	C.ggml_backend_sched_reset(c.b.sched)

	return nil
}

562
func (c *Context) MaxGraphNodes() int {
563
	return c.maxGraphNodes
Jesse Gross's avatar
Jesse Gross committed
564
565
}

566
567
568
func shapeToGGML(shape []int) *C.int64_t {
	sh := make([]C.int64_t, len(shape))
	for i, s := range shape {
569
		sh[i] = C.int64_t(s)
570
571
572
573
574
	}

	return &sh[0]
}

575
576
577
578
func pad(length, pad C.size_t) C.size_t {
	return ((length + pad - 1) / pad) * pad
}

579
func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
580
	if c.buft == nil {
581
		panic("set Input or Layer before creating tensors")
582
583
	}

Michael Yang's avatar
Michael Yang committed
584
585
586
587
588
589
	var cdtype uint32
	switch dtype {
	case ml.DTypeF32:
		cdtype = C.GGML_TYPE_F32
	case ml.DTypeF16:
		cdtype = C.GGML_TYPE_F16
590
591
592
593
	case ml.DTypeQ80:
		cdtype = C.GGML_TYPE_Q8_0
	case ml.DTypeQ40:
		cdtype = C.GGML_TYPE_Q4_0
Michael Yang's avatar
Michael Yang committed
594
595
596
597
598
599
	case ml.DTypeI32:
		cdtype = C.GGML_TYPE_I32
	default:
		panic("unsupported dtype")
	}

Jesse Gross's avatar
Jesse Gross committed
600
	if len(shape) < 1 || shape[0] == 0 {
Michael Yang's avatar
Michael Yang committed
601
		var shape C.int64_t = 0
602
		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
Michael Yang's avatar
Michael Yang committed
603
	} else if len(shape) > 4 {
Michael Yang's avatar
Michael Yang committed
604
605
606
607
608
609
610
611
612
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

Michael Yang's avatar
Michael Yang committed
613
	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
614
615
	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
616
617
618
	if b == nil {
		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
	}
619
	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
620

Michael Yang's avatar
Michael Yang committed
621
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
622
	return &Tensor{b: c.b, t: t}, nil
623
624
}

625
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
626
627
628
629
630
631
	t, err := c.newTensor(dtype, shape)
	if err != nil {
		panic(err)
	}

	return t
632
633
}

634
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
635
636
637
638
639
	t, err := c.newTensor(dtype, shape)
	if err != nil {
		panic(err)
	}

640
641
	C.ggml_set_zero(t.(*Tensor).t)
	return t
Michael Yang's avatar
Michael Yang committed
642
643
}

644
func checkShape[S ~[]E, E any](s S, shape ...int) error {
Michael Yang's avatar
Michael Yang committed
645
	n := len(s)
Jesse Gross's avatar
Jesse Gross committed
646
647
648
649
650

	if n == 0 {
		return nil
	}

Michael Yang's avatar
Michael Yang committed
651
652
653
654
655
	for _, v := range shape {
		n /= v
	}

	if n != 1 {
656
		return fmt.Errorf("invalid shape: %v", shape)
Michael Yang's avatar
Michael Yang committed
657
658
	}

659
	return nil
Michael Yang's avatar
Michael Yang committed
660
661
}

662
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
663
	if err := checkShape(s, shape...); err != nil {
664
665
666
		return nil, err
	}

667
668
669
670
671
	t, err := c.newTensor(ml.DTypeF32, shape)
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
672
673
674
675
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

676
	return t, nil
Michael Yang's avatar
Michael Yang committed
677
678
}

679
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
680
	if err := checkShape(s, shape...); err != nil {
681
682
683
		return nil, err
	}

684
685
686
687
688
	t, err := c.newTensor(ml.DTypeI32, shape)
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
689
690
691
692
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

693
	return t, nil
Michael Yang's avatar
Michael Yang committed
694
695
}

Michael Yang's avatar
arange  
Michael Yang committed
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	switch dtype {
	case ml.DTypeF32:
		// ggml_arange creates a float32 tensor
		return &Tensor{
			b: c.b,
			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
		}
	case ml.DTypeI32:
		// ggml_cast does not support float32 to int32 conversion
		arange := make([]int32, 0, int((stop-start)/step))
		for i := start; i < stop; i += step {
			arange = append(arange, int32(i))
		}

		t, err := c.Input().FromIntSlice(arange, len(arange))
		if err != nil {
			panic(err)
		}

		return t
	default:
		panic("unsupported dtype for arange")
	}
}

Michael Yang's avatar
Michael Yang committed
722
723
func (c *Context) Close() {
	if c != nil {
724
725
726
727
728
		for _, b := range *c.allocatedBuffers {
			C.ggml_backend_buffer_free(b)
		}
		*c.allocatedBuffers = nil

729
730
		C.ggml_free(c.ctx)
	}
Michael Yang's avatar
Michael Yang committed
731
732
733
}

type Tensor struct {
734
	b    *Backend
Michael Yang's avatar
Michael Yang committed
735
	t    *C.struct_ggml_tensor
736
	sync func()
Michael Yang's avatar
Michael Yang committed
737
738
739
740
741
742
743
744
745
746
}

func (t *Tensor) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", C.GoString(C.ggml_get_name(t.t))),
		slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
		slog.Any("shape", t.Shape()),
	)
}

747
748
func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
Michael Yang's avatar
Michael Yang committed
749
750
}

751
752
func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
Michael Yang's avatar
Michael Yang committed
753
754
}

755
756
func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
Michael Yang's avatar
Michael Yang committed
757
758
759
760
761
762
763
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

764
765
766
767
768
769
770
771
772
func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
Michael Yang's avatar
Michael Yang committed
773
774
}

775
776
777
778
779
780
func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
Michael Yang's avatar
Michael Yang committed
781
782
783
784
785
786
787
788
789
	}

	return
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
Jesse Gross's avatar
Jesse Gross committed
790
791
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
792
793
794
795
	case C.GGML_TYPE_Q8_0:
		return ml.DTypeQ80
	case C.GGML_TYPE_Q4_0:
		return ml.DTypeQ40
Michael Yang's avatar
Michael Yang committed
796
797
798
799
800
801
802
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
		return ml.DTypeOther
	}
}

803
804
805
806
807
808
809
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
	}
}

Michael Yang's avatar
Michael Yang committed
810
811
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
812
		b: t.b,
Michael Yang's avatar
Michael Yang committed
813
814
815
816
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	}

	shape := make([]C.int64_t, C.GGML_MAX_DIMS)
	for i := range C.GGML_MAX_DIMS {
		if i == dim {
			shape[i] = C.int64_t(t.Dim(i) * n)
		} else {
			shape[i] = C.int64_t(t.Dim(i))
		}
	}

	tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
	return &Tensor{
		b: t.b,
		t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
	}
}

Michael Yang's avatar
Michael Yang committed
838
839
840
841
842
843
844
845
846
847
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) > 0 {
		return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
	}

	return t
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
848
		b: t.b,
Michael Yang's avatar
Michael Yang committed
849
850
851
852
853
854
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
	return &Tensor{
855
		b: t.b,
Michael Yang's avatar
Michael Yang committed
856
857
858
859
860
861
		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
862
		b: t.b,
Michael Yang's avatar
Michael Yang committed
863
864
865
866
867
868
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
869
		b: t.b,
Michael Yang's avatar
Michael Yang committed
870
871
872
873
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

874
875
876
877
878
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
879
		b: t.b,
880
881
882
883
		t: mul,
	}
}

Michael Yang's avatar
llama4  
Michael Yang committed
884
885
886
887
888
889
890
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
	}
}

Michael Yang's avatar
Michael Yang committed
891
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
892
893
894
895
896
897
	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
		if b != nil {
			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
		}
Michael Yang's avatar
Michael Yang committed
898
899
	}

Michael Yang's avatar
llama4  
Michael Yang committed
900
	return &Tensor{b: t.b, t: tt}
Michael Yang's avatar
Michael Yang committed
901
902
903
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
904
905
906
907
908
909
	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
	}

	return &Tensor{b: t.b, t: tt}
Michael Yang's avatar
Michael Yang committed
910
911
}

912
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
913
914
915
916
917
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
918
		b: t.b,
Michael Yang's avatar
Michael Yang committed
919
920
921
922
923
924
925
926
927
928
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
929
		b: t.b,
Michael Yang's avatar
Michael Yang committed
930
931
932
933
934
935
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
936
		b: t.b,
Michael Yang's avatar
Michael Yang committed
937
938
939
940
941
942
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
943
		b: t.b,
Michael Yang's avatar
Michael Yang committed
944
945
946
947
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

948
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
949
950
951
	switch len(shape) {
	case 1:
		return &Tensor{
952
			b: t.b,
Michael Yang's avatar
Michael Yang committed
953
954
955
956
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
957
			b: t.b,
Michael Yang's avatar
Michael Yang committed
958
959
960
961
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
962
			b: t.b,
Michael Yang's avatar
Michael Yang committed
963
964
965
966
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
967
			b: t.b,
Michael Yang's avatar
Michael Yang committed
968
969
970
971
972
973
974
975
976
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
977
		b: t.b,
Michael Yang's avatar
Michael Yang committed
978
979
980
981
982
983
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
984
		b: t.b,
Michael Yang's avatar
Michael Yang committed
985
986
987
988
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sin(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cos(ctx.(*Context).ctx, t.t),
	}
}

Michael Yang's avatar
Michael Yang committed
1003
1004
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
1005
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1006
1007
1008
1009
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

Michael Yang's avatar
llama4  
Michael Yang committed
1010
1011
1012
1013
1014
1015
1016
func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
	}
}

1017
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
1018
1019
1020
1021
1022
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
1023
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1024
1025
1026
1027
1028
1029
1030
1031
		t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
1032
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1033
1034
1035
1036
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
1037
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1038
1039
1040
1041
1042
1043
1044
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
1045
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1046
1047
1048
1049
1050
1051
1052
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
1053
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

const (
Patrick Devine's avatar
Patrick Devine committed
1065
1066
1067
1068
	ropeTypeNorm   C.int = 0
	ropeTypeNeox   C.int = 2
	ropeTypeMrope  C.int = 8
	ropeTypeVision C.int = 24
Michael Yang's avatar
Michael Yang committed
1069
1070
)

Patrick Devine's avatar
Patrick Devine committed
1071
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
1072
	if ropeFactors == nil {
1073
		ropeFactors = &Tensor{b: t.b}
Michael Yang's avatar
Michael Yang committed
1074
1075
	}

Jesse Gross's avatar
Jesse Gross committed
1076
1077
1078
1079
1080
	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

Michael Yang's avatar
Michael Yang committed
1081
	return &Tensor{
1082
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1083
		t: C.ggml_rope_ext(
Jesse Gross's avatar
Jesse Gross committed
1084
			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
Michael Yang's avatar
Michael Yang committed
1085
			C.int(ropeDim),
Patrick Devine's avatar
Patrick Devine committed
1086
1087
			C.int(ropeType),
			131072, // YaRN n_ctx_train
Michael Yang's avatar
Michael Yang committed
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
			C.float(ropeBase),
			C.float(ropeScale),
			0.,  // YaRN ext_factor
			1.,  // YaRN attn_factor
			32., // YaRN beta_fast
			1.,  // YaRN beta_slow
		),
	}
}

1098
1099
1100
1101
1102
1103
1104
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
	}
}

Michael Yang's avatar
Michael Yang committed
1105
1106
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
	return &Tensor{
1107
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1108
1109
1110
1111
1112
1113
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
	return &Tensor{
1114
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1115
1116
1117
1118
1119
1120
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
1121
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1122
1123
1124
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}
1125

Michael Yang's avatar
Michael Yang committed
1126
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
1127
1128
	return &Tensor{
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1129
		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
Michael Yang's avatar
Michael Yang committed
1130
1131
1132
	}
}

Michael Yang's avatar
Michael Yang committed
1133
1134
1135
1136
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
Michael Yang's avatar
Michael Yang committed
1137
		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
Michael Yang's avatar
Michael Yang committed
1138
	case 1:
Michael Yang's avatar
Michael Yang committed
1139
		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
Michael Yang's avatar
Michael Yang committed
1140
1141
1142
1143
1144
1145
1146
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}

1147
1148
1149
1150
1151
1152
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

1153
1154
1155
	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

1156
1157
	if t.b.flashAttention {
		value = value.Permute(ctx, 0, 2, 1, 3)
1158

1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
		return &Tensor{b: t.b, t: kqv}
	} else {
		kq := key.MulmatFullPrec(ctx, query)
		kq = &Tensor{
			b: t.b,
			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
		}

		kqv := value.Mulmat(ctx, kq)
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
1172
}
1173
1174
1175
1176
1177
1178
1179

func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
	}
}
Michael Yang's avatar
llama4  
Michael Yang committed
1180
1181
1182
1183
1184
1185
1186

func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
	}
}