ggml.go 30.1 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
package ggml

3
4
5
6
7
8
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
Michael Yang's avatar
Michael Yang committed
9
10
11
import "C"

import (
12
	"context"
13
	"errors"
Michael Yang's avatar
Michael Yang committed
14
15
16
	"fmt"
	"io"
	"log/slog"
17
	"maps"
Michael Yang's avatar
Michael Yang committed
18
	"os"
19
	"runtime"
20
21
22
	"slices"
	"strconv"
	"strings"
23
	"sync/atomic"
24
	"unicode"
Michael Yang's avatar
Michael Yang committed
25
26
27
	"unsafe"

	"github.com/ollama/ollama/format"
28
29
	"github.com/ollama/ollama/fs"
	fsggml "github.com/ollama/ollama/fs/ggml"
30
	"github.com/ollama/ollama/logutil"
Michael Yang's avatar
Michael Yang committed
31
	"github.com/ollama/ollama/ml"
32
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
Michael Yang's avatar
Michael Yang committed
33
34
35
	"golang.org/x/sync/errgroup"
)

Michael Yang's avatar
Michael Yang committed
36
37
38
39
40
func devices() []*C.struct_ggml_backend_device {
	ggml.OnceLoad()
	ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
	for i := range ds {
		ds[i] = C.ggml_backend_dev_get(C.size_t(i))
Michael Yang's avatar
Michael Yang committed
41
	}
Michael Yang's avatar
Michael Yang committed
42
43

	return ds
44
}
Michael Yang's avatar
Michael Yang committed
45
46

type Backend struct {
47
48
49
	// modelPath is the location of the model data
	modelPath string

50
51
	meta *fsggml.GGML

52
53
54
55
	// tensorLoadTargets maps from the name of the tensor in the file
	// to the name that is used by the model definition
	tensorLoadTargets map[string][]string

56
57
58
59
	sched         *C.struct_ggml_backend_sched
	schedBackends []*C.struct_ggml_backend
	schedBufts    []*C.struct_ggml_backend_buffer_type

60
	tensors map[string]*C.struct_ggml_tensor
Michael Yang's avatar
Michael Yang committed
61
62

	// input is the backend used for inputs
63
	input *C.struct_ggml_backend_buffer_type
Michael Yang's avatar
Michael Yang committed
64
65

	// layers is the backend used for repeating layers
66
	layers map[int]*C.struct_ggml_backend_buffer_type
67

68
	flashAttention bool
Michael Yang's avatar
Michael Yang committed
69
70
71

	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
72
73
}

74
75
76
77
78
79
80
81
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
	r, err := os.Open(modelPath)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	meta, err := fsggml.Decode(r, -1)
Michael Yang's avatar
Michael Yang committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
	if err != nil {
		return nil, err
	}

	slog.Info(
		"",
		"architecture", meta.KV().Architecture(),
		"file_type", meta.KV().FileType(),
		"name", meta.KV().String("general.name"),
		"description", meta.KV().String("general.description"),
		"num_tensors", len(meta.Tensors().Items()),
		"num_key_values", len(meta.KV()),
	)

96
	type deviceBufferType struct {
97
98
99
100
101
		d   *C.struct_ggml_backend_device
		bts []*C.struct_ggml_backend_buffer_type
	}

	var cpus, accels, gpus []*C.struct_ggml_backend_device
Michael Yang's avatar
Michael Yang committed
102
	for _, d := range devices() {
103
104
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
105
106
107
108
			if len(cpus) == 0 {
				// only the first cpu device should be used
				cpus = append(cpus, d)
			}
109
110
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
Michael Yang's avatar
Michael Yang committed
111
		case C.GGML_BACKEND_DEVICE_TYPE_GPU:
112
			gpus = append(gpus, d)
Michael Yang's avatar
Michael Yang committed
113
114
115
		}
	}

Michael Yang's avatar
Michael Yang committed
116
	// create list of buffer types for the cpu
Michael Yang's avatar
Michael Yang committed
117
	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
118
119
120
121
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
Michael Yang's avatar
Michael Yang committed
122
			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
Michael Yang's avatar
Michael Yang committed
123
		}
124
125
	}

Michael Yang's avatar
Michael Yang committed
126
	// create list of buffer types for each gpu
127
	var gpuDeviceBufferTypes []deviceBufferType
128
129
	for _, d := range gpus {
		bt := C.ggml_backend_dev_buffer_type(d)
130
		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
131
			d:   d,
Michael Yang's avatar
Michael Yang committed
132
			bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
133
		})
Michael Yang's avatar
Michael Yang committed
134
135
	}

Michael Yang's avatar
Michael Yang committed
136
137
138
139
140
	useDefaultSplit := true
	for _, s := range params.TensorSplit {
		if s != 0 {
			useDefaultSplit = false
			break
141
		}
Michael Yang's avatar
Michael Yang committed
142
	}
143

Michael Yang's avatar
Michael Yang committed
144
145
146
147
	// calculate splits
	splits := make([]float32, len(gpus))
	if useDefaultSplit {
		// default: split on free memory
148
149
150
151
152
		for i := range splits {
			var free, total C.size_t
			C.ggml_backend_dev_memory(gpus[i], &free, &total)
			splits[i] = float32(free)
		}
Michael Yang's avatar
Michael Yang committed
153
154
	} else {
		splits = params.TensorSplit
155
156
157
	}

	var sum float32
Michael Yang's avatar
Michael Yang committed
158
	// cumulative sum of all splits
159
160
161
162
163
	for i := range splits {
		sum += splits[i]
		splits[i] = sum
	}

Michael Yang's avatar
Michael Yang committed
164
	// normalize splits
165
	for i := range splits {
166
		splits[i] /= sum
167
168
	}

Michael Yang's avatar
Michael Yang committed
169
	// inputs always use cpu
Michael Yang's avatar
Michael Yang committed
170
	input := cpuDeviceBufferType
171

172
	blocks := int(meta.KV().BlockCount())
Michael Yang's avatar
Michael Yang committed
173
174
175
176

	// define a range of gpu layers. anything outside of this range is assigned to the cpu
	gpuRangeStart := max(0, blocks-params.NumGPULayers)
	gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
Michael Yang's avatar
Michael Yang committed
177
	assignLayer := func(i int) deviceBufferType {
Michael Yang's avatar
Michael Yang committed
178
		if i < gpuRangeStart || i >= gpuRangeStop {
Michael Yang's avatar
Michael Yang committed
179
			return cpuDeviceBufferType
180
		}
181

Michael Yang's avatar
Michael Yang committed
182
		index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
183
		if index < 0 || index >= len(gpuDeviceBufferTypes) {
Michael Yang's avatar
Michael Yang committed
184
			return cpuDeviceBufferType
185
186
187
		}

		return gpuDeviceBufferTypes[index]
188
189
	}

Michael Yang's avatar
Michael Yang committed
190
	// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
191
	layers := make([]deviceBufferType, blocks)
192
	for i := range layers {
193
		layers[i] = assignLayer(i)
194
195
	}

Michael Yang's avatar
Michael Yang committed
196
	// outputs are assigned iff allowed by splits and configured number of gpu layers
197
	output := assignLayer(blocks)
198
199
200

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
Michael Yang's avatar
Michael Yang committed
201
	// each layer has at most 2 extra tensors for rope operations
202
203
	maxTensors += blocks * 2

204
	type tensor struct {
205
		source *fsggml.Tensor
206
207
208
		target string
	}

Michael Yang's avatar
Michael Yang committed
209
	// some tensors are mapped to different names so keep a list
210
211
	targets := make(map[string][]string)

Michael Yang's avatar
Michael Yang committed
212
	// contexts are shared by tensors of the same buffer type
213
	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
214
	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
215
216
217
218
219
220
221
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}
Michael Yang's avatar
Michael Yang committed
222

223
224
225
226
227
228
229
230
			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
Michael Yang's avatar
Michael Yang committed
231
			defer C.free(unsafe.Pointer(cname))
232
233
234
235
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

236
			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
Michael Yang's avatar
Michael Yang committed
237
238
			C.ggml_set_name(tt, cname)

239
			slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
240
241
242
243
244
			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
Michael Yang's avatar
Michael Yang committed
245
246
	}

247
	contains := func(s string, parts ...string) bool {
248
249
250
251
252
253
254
255
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
Michael Yang's avatar
Michael Yang committed
256
257
	}

258
259
	for _, t := range meta.Tensors().Items() {
		switch {
260
		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
261
			createTensor(tensor{source: t}, input.bts)
Michael Yang's avatar
Michael Yang committed
262
263
264
			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
			}
265
		case contains(t.Name, "cls", "output", "output_norm"):
266
			createTensor(tensor{source: t}, output.bts)
267
		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
Michael Yang's avatar
Michael Yang committed
268
			// TODO: assign vision tensors to the gpu if possible
Michael Yang's avatar
Michael Yang committed
269
			createTensor(tensor{source: t}, output.bts)
Michael Yang's avatar
Michael Yang committed
270
271
272
273
274
275
276
277
		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
			// these tensors should be repeated per layer
			for i, layer := range layers {
				createTensor(tensor{
					source: t,
					target: "blk." + strconv.Itoa(i) + "." + t.Name,
				}, layer.bts)
			}
278
		default:
Michael Yang's avatar
Michael Yang committed
279
280
281
282
			layerIndex := -1
			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
				if i, err := strconv.Atoi(fields[0]); err == nil {
					layerIndex = i
283
				}
Michael Yang's avatar
Michael Yang committed
284
			}
285

Michael Yang's avatar
Michael Yang committed
286
287
			if layerIndex >= 0 {
				createTensor(tensor{source: t}, layers[layerIndex].bts)
288
			} else {
Michael Yang's avatar
Michael Yang committed
289
290
				// load all other tensors on the cpu
				createTensor(tensor{source: t}, input.bts)
291
292
293
			}
		}
	}
Michael Yang's avatar
Michael Yang committed
294

Michael Yang's avatar
Michael Yang committed
295
296
	// allocate buffers for each context
	bbs := make(map[*C.struct_ggml_context]*C.struct_ggml_backend_buffer, len(ctxs))
297
298
299
300
301
302
	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
303
304
305
306
		if b == nil {
			return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
		}

307
		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
Michael Yang's avatar
Michael Yang committed
308
		bbs[c] = b
309
310
311
	}

	for bs := range maps.Values(bbs) {
Michael Yang's avatar
Michael Yang committed
312
		slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
313
314
	}

Michael Yang's avatar
Michael Yang committed
315
	// map tensor names to tensors for easy lookup later
316
317
318
319
320
321
322
	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
	// map devices to backend buffer types so new tensors can be assigned to the correct device
	deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)

	// create backends and buffer types used for the compute graph scheduler
	var schedBackends []*C.struct_ggml_backend
	var schedBufts []*C.struct_ggml_backend_buffer_type
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := C.ggml_backend_dev_init(d, nil)
		bt := C.ggml_backend_get_default_buffer_type(b)

		deviceBufferTypes[d] = bt

		schedBackends = append(schedBackends, b)
		schedBufts = append(schedBufts, bt)

		if C.ggml_backend_is_cpu(b) {
			// set number of threads for cpu backend
			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
		}
	}

	maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
	return &Backend{
		modelPath:         modelPath,
		flashAttention:    params.FlashAttention,
		meta:              meta,
		tensorLoadTargets: targets,
		tensors:           tensors,
		sched: C.ggml_backend_sched_new(
			(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
			C.int(len(schedBackends)),
			C.size_t(maxGraphNodes),
			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
			C._Bool(false),
		),
		schedBackends: schedBackends,
		schedBufts:    schedBufts,
		input:         deviceBufferTypes[input.d],
		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
			m := make(map[int]*C.struct_ggml_backend_buffer_type)
			for i, layer := range layers {
				m[i] = deviceBufferTypes[layer.d]
			}
			return m
		}(),
		maxGraphNodes: maxGraphNodes,
	}, nil
}

// init registers New with the ml package under the backend name "ggml".
func init() {
	ml.RegisterBackend("ggml", New)
}

func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
378
	var doneBytes atomic.Uint64
379
	totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
380
381
382

	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))
383
	for _, t := range b.meta.Tensors().Items() {
384
		t := t
385
		g.Go(func() error {
386
			tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
387
			for i := range tts {
388
				target := b.tensorLoadTargets[t.Name][i]
389
390
391
				if target == "" {
					target = t.Name
				}
392

393
				tt, ok := b.tensors[target]
394
395
396
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}
Michael Yang's avatar
Michael Yang committed
397

398
399
400
				tts[i] = tt
			}

401
402
			// Create a new FD for each goroutine so that each FD is read sequentially, rather than
			// seeking around within an FD shared between all goroutines.
403
			file, err := os.Open(b.modelPath)
404
			if err != nil {
405
				slog.Warn("file open error", "file", b.modelPath, "error", err)
406
407
408
				return err
			}
			defer file.Close()
409
			sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
410
411
412
413
			bts := make([]byte, 128*format.KibiByte)

			var s uint64
			for s < t.Size() {
414
415
416
417
418
				// Stop if either the parent context has been canceled or if any of the other tensors returned an error
				if err := ctx.Err(); err != nil {
					return err
				}

419
420
				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
				if err != nil {
421
					slog.Warn("file read error", "file", b.modelPath, "error", err)
422
					return err
423
				}
Michael Yang's avatar
Michael Yang committed
424

425
426
				for _, tt := range tts {
					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
427
				}
Michael Yang's avatar
Michael Yang committed
428

429
430
				s += uint64(n)

431
				if progress != nil {
432
					done := doneBytes.Add(uint64(n))
433
					progress(float32(done) / float32(totalBytes))
434
435
436
437
438
				}
			}

			return nil
		})
Michael Yang's avatar
Michael Yang committed
439
440
	}

441
	if err := g.Wait(); err != nil {
442
		return err
443
444
	}

445
	return nil
Michael Yang's avatar
Michael Yang committed
446
447
}

448
func (b *Backend) Config() fs.Config {
Michael Yang's avatar
Michael Yang committed
449
450
451
452
	return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
453
454
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
Michael Yang's avatar
Michael Yang committed
455
456
457
458
459
460
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
Michael Yang's avatar
Michael Yang committed
461
	return b.NewContextSize(b.maxGraphNodes)
462
463
464
}

func (b *Backend) NewContextSize(n int) ml.Context {
Jesse Gross's avatar
Jesse Gross committed
465
466
467
468
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}

469
470
	var allocatedBuffers []*C.struct_ggml_backend_buffer

Michael Yang's avatar
Michael Yang committed
471
	return &Context{
472
473
		b:             b,
		maxGraphNodes: n,
474
		ctx: C.ggml_init(C.struct_ggml_init_params{
475
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
476
477
			no_alloc: true,
		}),
478
		allocatedBuffers: &allocatedBuffers,
Michael Yang's avatar
Michael Yang committed
479
480
481
	}
}

482
// CacheConfig returns the cache layout settings for this backend, which
// differ depending on whether flash attention is enabled.
func (b *Backend) CacheConfig() ml.CacheConfig {
	if !b.flashAttention {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}

	return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
}

Michael Yang's avatar
Michael Yang committed
490
type Context struct {
491
	b *Backend
Michael Yang's avatar
Michael Yang committed
492

493
	ctx   *C.struct_ggml_context
Michael Yang's avatar
Michael Yang committed
494
	graph *C.struct_ggml_cgraph
495

496
497
	// buft is the buffer type used for new tensors
	buft *C.struct_ggml_backend_buffer_type
498

499
500
501
502
	// allocatedBuffers are buffers for tensors that we have allocated in this context
	// so that we can free them when we close the context
	allocatedBuffers *[]*C.struct_ggml_backend_buffer

Michael Yang's avatar
Michael Yang committed
503
	// maxGraphNodes is the maximum allowed number of graph nodes in this context
504
	maxGraphNodes int
Michael Yang's avatar
Michael Yang committed
505
506
}

507
func (c *Context) Input() ml.Context {
Michael Yang's avatar
Michael Yang committed
508
	if c.b.input != nil {
509
		return &Context{
510
511
512
513
514
			b:                c.b,
			ctx:              c.ctx,
			buft:             c.b.input,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
515
516
517
		}
	}

518
	return c
519
520
}

521
func (c *Context) Layer(i int) ml.Context {
522
	if buft, ok := c.b.layers[i]; ok {
523
		return &Context{
524
525
526
527
528
			b:                c.b,
			ctx:              c.ctx,
			buft:             buft,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
529
530
531
		}
	}

532
	return c
533
534
}

535
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
Michael Yang's avatar
Michael Yang committed
536
	if c.graph == nil {
537
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
Michael Yang's avatar
Michael Yang committed
538
539
	}

540
541
542
543
544
	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
Michael Yang's avatar
Michael Yang committed
545
546
}

547
func (c *Context) Compute(tensors ...ml.Tensor) {
548
	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
Michael Yang's avatar
Michael Yang committed
549
	C.ggml_backend_sched_reset(c.b.sched)
Michael Yang's avatar
Michael Yang committed
550

551
552
553
	needSync := true
	sync := func() {
		if needSync {
554
			C.ggml_backend_sched_synchronize(c.b.sched)
555
556
557
			needSync = false
		}
	}
Michael Yang's avatar
Michael Yang committed
558

559
560
561
	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
562
563
		}
	}
Michael Yang's avatar
Michael Yang committed
564
565
}

566
// Reserve asks the scheduler to reserve buffers for the context's graph and
// logs the resulting buffer size of every scheduler backend. The scheduler is
// reset before returning, on success and on failure alike.
func (c *Context) Reserve() error {
	sched := c.b.sched

	if reserved := C.ggml_backend_sched_reserve(sched, c.graph); !reserved {
		C.ggml_backend_sched_reset(sched)
		return errors.New("failed to reserve graph")
	}

	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(sched))
	for i, backend := range c.b.schedBackends {
		size := C.ggml_backend_sched_get_buffer_size(sched, backend)
		slog.Info("compute graph",
			"backend", C.GoString(C.ggml_backend_name(backend)),
			"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
			"size", format.HumanBytes2(uint64(size)))
	}

	C.ggml_backend_sched_reset(sched)

	return nil
}

584
func (c *Context) MaxGraphNodes() int {
585
	return c.maxGraphNodes
Jesse Gross's avatar
Jesse Gross committed
586
587
}

588
589
590
// shapeToGGML converts shape to a freshly allocated array of C int64 values
// and returns a pointer to its first element. shape must not be empty.
func shapeToGGML(shape []int) *C.int64_t {
	dims := make([]C.int64_t, 0, len(shape))
	for _, d := range shape {
		dims = append(dims, C.int64_t(d))
	}

	return &dims[0]
}

597
598
599
600
// pad rounds length up to the nearest multiple of pad.
func pad(length, pad C.size_t) C.size_t {
	chunks := (length + pad - 1) / pad
	return chunks * pad
}

601
func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
602
	if c.buft == nil {
603
		panic("set Input or Layer before creating tensors")
604
605
	}

Michael Yang's avatar
Michael Yang committed
606
607
608
609
610
611
	var cdtype uint32
	switch dtype {
	case ml.DTypeF32:
		cdtype = C.GGML_TYPE_F32
	case ml.DTypeF16:
		cdtype = C.GGML_TYPE_F16
612
613
614
615
	case ml.DTypeQ80:
		cdtype = C.GGML_TYPE_Q8_0
	case ml.DTypeQ40:
		cdtype = C.GGML_TYPE_Q4_0
Michael Yang's avatar
Michael Yang committed
616
617
618
619
620
621
	case ml.DTypeI32:
		cdtype = C.GGML_TYPE_I32
	default:
		panic("unsupported dtype")
	}

Jesse Gross's avatar
Jesse Gross committed
622
	if len(shape) < 1 || shape[0] == 0 {
Michael Yang's avatar
Michael Yang committed
623
		var shape C.int64_t = 0
624
		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
Michael Yang's avatar
Michael Yang committed
625
	} else if len(shape) > 4 {
Michael Yang's avatar
Michael Yang committed
626
627
628
629
630
631
632
633
634
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

Michael Yang's avatar
Michael Yang committed
635
	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
636
637
	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
638
639
640
	if b == nil {
		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
	}
641
	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
642

Michael Yang's avatar
Michael Yang committed
643
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
644
	return &Tensor{b: c.b, t: t}, nil
645
646
}

647
// Empty creates a tensor of the given dtype and shape without initializing
// its contents. It panics when the tensor cannot be created.
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	created, err := c.newTensor(dtype, shape)
	if err == nil {
		return created
	}

	panic(err)
}

656
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
657
658
659
660
661
	t, err := c.newTensor(dtype, shape)
	if err != nil {
		panic(err)
	}

662
663
	C.ggml_set_zero(t.(*Tensor).t)
	return t
Michael Yang's avatar
Michael Yang committed
664
665
}

666
// checkShape verifies that the dimensions in shape multiply out to exactly
// len(s). An empty slice is accepted for any shape.
//
// Each dimension must be positive and must divide the remaining element count
// evenly; otherwise an "invalid shape" error is returned. Checking
// divisibility at every step avoids both the false positives of truncating
// integer division (e.g. 5 elements with shape 2x2) and a division by zero
// for a zero dimension.
func checkShape[S ~[]E, E any](s S, shape ...int) error {
	n := len(s)

	if n == 0 {
		return nil
	}

	for _, v := range shape {
		if v <= 0 || n%v != 0 {
			return fmt.Errorf("invalid shape: %v", shape)
		}

		n /= v
	}

	if n != 1 {
		return fmt.Errorf("invalid shape: %v", shape)
	}

	return nil
}

684
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
685
	if err := checkShape(s, shape...); err != nil {
686
687
688
		return nil, err
	}

689
690
691
692
693
	t, err := c.newTensor(ml.DTypeF32, shape)
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
694
695
696
697
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

698
	return t, nil
Michael Yang's avatar
Michael Yang committed
699
700
}

701
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
702
	if err := checkShape(s, shape...); err != nil {
703
704
705
		return nil, err
	}

706
707
708
709
710
	t, err := c.newTensor(ml.DTypeI32, shape)
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
711
712
713
714
	if len(s) > 0 {
		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
	}

715
	return t, nil
Michael Yang's avatar
Michael Yang committed
716
717
}

Michael Yang's avatar
arange  
Michael Yang committed
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
// Arange returns a one-dimensional tensor holding the values start,
// start+step, ... that are below stop.
//
// For ml.DTypeF32 the tensor is produced by ggml_arange. For ml.DTypeI32 the
// values are generated in Go, truncated to int32 and uploaded through the
// input context; an empty range (start >= stop) yields an empty tensor. It
// panics for any other dtype, for a non-positive step on a non-empty I32
// range (which could never terminate), or when the I32 tensor cannot be
// created.
func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	switch dtype {
	case ml.DTypeF32:
		// ggml_arange creates a float32 tensor
		return &Tensor{
			b: c.b,
			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
		}
	case ml.DTypeI32:
		// ggml_cast does not support float32 to int32 conversion
		var arange []int32
		if start < stop {
			if step <= 0 {
				panic("arange step must be positive")
			}

			arange = make([]int32, 0, int((stop-start)/step))
			for i := start; i < stop; i += step {
				arange = append(arange, int32(i))
			}
		}

		t, err := c.Input().FromIntSlice(arange, len(arange))
		if err != nil {
			panic(err)
		}

		return t
	default:
		panic("unsupported dtype for arange")
	}
}

Michael Yang's avatar
Michael Yang committed
744
745
func (c *Context) Close() {
	if c != nil {
746
747
748
749
750
		for _, b := range *c.allocatedBuffers {
			C.ggml_backend_buffer_free(b)
		}
		*c.allocatedBuffers = nil

751
752
		C.ggml_free(c.ctx)
	}
Michael Yang's avatar
Michael Yang committed
753
754
755
}

type Tensor struct {
756
	b    *Backend
Michael Yang's avatar
Michael Yang committed
757
	t    *C.struct_ggml_tensor
758
	sync func()
Michael Yang's avatar
Michael Yang committed
759
760
761
762
763
764
765
766
767
768
}

// LogValue renders the tensor for structured logging as a group holding its
// name, ggml type name and shape.
func (t *Tensor) LogValue() slog.Value {
	name := C.GoString(C.ggml_get_name(t.t))
	typeName := C.GoString(C.ggml_type_name(t.t._type))

	return slog.GroupValue(
		slog.String("name", name),
		slog.String("type", typeName),
		slog.Any("shape", t.Shape()),
	)
}

769
770
func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
Michael Yang's avatar
Michael Yang committed
771
772
}

773
774
func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
Michael Yang's avatar
Michael Yang committed
775
776
}

777
778
func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
Michael Yang's avatar
Michael Yang committed
779
780
781
782
783
784
785
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

786
787
788
789
790
791
792
793
794
func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
Michael Yang's avatar
Michael Yang committed
795
796
}

797
798
799
800
801
802
func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
Michael Yang's avatar
Michael Yang committed
803
804
805
806
807
808
809
810
811
	}

	return
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
Jesse Gross's avatar
Jesse Gross committed
812
813
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
814
815
816
817
	case C.GGML_TYPE_Q8_0:
		return ml.DTypeQ80
	case C.GGML_TYPE_Q4_0:
		return ml.DTypeQ40
Michael Yang's avatar
Michael Yang committed
818
819
820
821
822
823
824
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	default:
		return ml.DTypeOther
	}
}

825
826
827
828
829
830
831
// Neg adds a ggml_neg operation on t to ctx and returns the result tensor.
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
	negated := C.ggml_neg(ctx.(*Context).ctx, t.t)
	return &Tensor{b: t.b, t: negated}
}

Michael Yang's avatar
Michael Yang committed
832
833
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
834
		b: t.b,
Michael Yang's avatar
Michael Yang committed
835
836
837
838
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
// Repeat repeats t along dimension dim so that dimension becomes n times its
// current size; all other dimensions keep their size. It panics when dim is
// outside [0, GGML_MAX_DIMS).
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	}

	// start from the current shape, then scale only the requested dimension
	shape := make([]C.int64_t, C.GGML_MAX_DIMS)
	for i := range shape {
		shape[i] = C.int64_t(t.Dim(i))
	}
	shape[dim] = C.int64_t(t.Dim(dim) * n)

	// ggml_repeat takes the target shape from a template tensor
	cctx := ctx.(*Context).ctx
	tmpl := C.ggml_new_tensor(cctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
	return &Tensor{
		b: t.b,
		t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
	}
}

Michael Yang's avatar
Michael Yang committed
860
861
862
863
864
865
866
867
868
869
// Stack concatenates t with the tensors in s along dim. The tensors in s are
// combined first (recursively, from the right) and t is concatenated with the
// result. With no extra tensors, t is returned unchanged.
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) == 0 {
		return t
	}

	head, tail := s[0], s[1:]
	return t.Concat(ctx, head.Stack(ctx, dim, tail...), dim)
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
870
		b: t.b,
Michael Yang's avatar
Michael Yang committed
871
872
873
874
875
876
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
	return &Tensor{
877
		b: t.b,
Michael Yang's avatar
Michael Yang committed
878
879
880
881
882
883
		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
884
		b: t.b,
Michael Yang's avatar
Michael Yang committed
885
886
887
888
889
890
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
891
		b: t.b,
Michael Yang's avatar
Michael Yang committed
892
893
894
895
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

896
897
898
899
900
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
901
		b: t.b,
902
903
904
905
		t: mul,
	}
}

Michael Yang's avatar
llama4  
Michael Yang committed
906
907
908
909
910
911
912
// MulmatID builds a ggml_mul_mat_id node from t, t2 and the ids tensor and
// returns it as a Tensor on t's backend.
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	selected := C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t)
	return &Tensor{b: t.b, t: selected}
}

Michael Yang's avatar
Michael Yang committed
913
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
914
915
916
917
918
919
	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
		if b != nil {
			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
		}
Michael Yang's avatar
Michael Yang committed
920
921
	}

Michael Yang's avatar
llama4  
Michael Yang committed
922
	return &Tensor{b: t.b, t: tt}
Michael Yang's avatar
Michael Yang committed
923
924
925
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
926
927
928
929
930
931
	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
	}

	return &Tensor{b: t.b, t: tt}
Michael Yang's avatar
Michael Yang committed
932
933
}

934
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
935
936
	if len(shape) != 4 {
		panic("expected 4 dimensions")
937
938
	} else if shape[3] != 0 {
		panic("cuda does not support 4d tensors")
Michael Yang's avatar
Michael Yang committed
939
940
941
	}

	return &Tensor{
942
		b: t.b,
Michael Yang's avatar
Michael Yang committed
943
944
945
946
947
948
949
950
951
952
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	}

	return &Tensor{
953
		b: t.b,
Michael Yang's avatar
Michael Yang committed
954
955
956
957
958
959
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
960
		b: t.b,
Michael Yang's avatar
Michael Yang committed
961
962
963
964
965
966
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
967
		b: t.b,
Michael Yang's avatar
Michael Yang committed
968
969
970
971
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

972
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
973
974
975
	switch len(shape) {
	case 1:
		return &Tensor{
976
			b: t.b,
Michael Yang's avatar
Michael Yang committed
977
978
979
980
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
981
			b: t.b,
Michael Yang's avatar
Michael Yang committed
982
983
984
985
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
986
			b: t.b,
Michael Yang's avatar
Michael Yang committed
987
988
989
990
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
991
			b: t.b,
Michael Yang's avatar
Michael Yang committed
992
993
994
995
996
997
998
999
1000
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
1001
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1002
1003
1004
1005
1006
1007
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
1008
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1009
1010
1011
1012
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
// Sin builds a ggml_sin node for t and returns it as a Tensor on t's backend.
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
	sine := C.ggml_sin(ctx.(*Context).ctx, t.t)
	return &Tensor{b: t.b, t: sine}
}

// Cos builds a ggml_cos node for t and returns it as a Tensor on t's backend.
func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
	cosine := C.ggml_cos(ctx.(*Context).ctx, t.t)
	return &Tensor{b: t.b, t: cosine}
}

Michael Yang's avatar
Michael Yang committed
1027
1028
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
1029
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1030
1031
1032
1033
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

Michael Yang's avatar
llama4  
Michael Yang committed
1034
1035
1036
1037
1038
1039
1040
// Sigmoid applies ggml's in-place sigmoid op (ggml_sigmoid_inplace) to t.
//
// NOTE(review): because the in-place variant is used, the result presumably
// shares storage with t — confirm before reading t again afterwards.
func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
	activated := C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t)
	return &Tensor{b: t.b, t: activated}
}

Michael Yang's avatar
Michael Yang committed
1041
1042
1043
1044
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
1045
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1046
1047
1048
1049
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
1050
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1051
1052
1053
1054
1055
1056
1057
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
1058
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1059
1060
1061
1062
1063
1064
1065
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
1066
			b: t.b,
Michael Yang's avatar
Michael Yang committed
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

const (
Patrick Devine's avatar
Patrick Devine committed
1078
1079
1080
1081
	ropeTypeNorm   C.int = 0
	ropeTypeNeox   C.int = 2
	ropeTypeMrope  C.int = 8
	ropeTypeVision C.int = 24
Michael Yang's avatar
Michael Yang committed
1082
1083
)

1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
	// Default options
	opts := &ml.RopeOptions{
		OriginalContextLen: 131072,
	}

	// Apply any provided options
	for _, option := range options {
		option(opts)
	}

Michael Yang's avatar
Michael Yang committed
1095
	if ropeFactors == nil {
1096
		ropeFactors = &Tensor{b: t.b}
Michael Yang's avatar
Michael Yang committed
1097
1098
	}

Jesse Gross's avatar
Jesse Gross committed
1099
1100
1101
1102
1103
	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

Michael Yang's avatar
Michael Yang committed
1104
	return &Tensor{
1105
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1106
		t: C.ggml_rope_ext(
1107
1108
1109
1110
			ctx.(*Context).ctx,
			dequant,
			positionIDs.(*Tensor).t,
			ropeFactors.(*Tensor).t,
Michael Yang's avatar
Michael Yang committed
1111
			C.int(ropeDim),
Patrick Devine's avatar
Patrick Devine committed
1112
			C.int(ropeType),
1113
			C.int(opts.OriginalContextLen),
Michael Yang's avatar
Michael Yang committed
1114
1115
			C.float(ropeBase),
			C.float(ropeScale),
1116
1117
1118
1119
			C.float(0.0),
			C.float(1.0),
			C.float(32.0),
			C.float(1.0),
Michael Yang's avatar
Michael Yang committed
1120
1121
1122
1123
		),
	}
}

1124
1125
1126
1127
1128
1129
1130
// IM2Col builds a ggml_im2col node from t and t2. s0/s1, p0/p1 and d0/d1 are
// forwarded as given (presumably stride, padding and dilation per axis —
// verify against ggml.h). The 2D flag is always true and the destination type
// is always GGML_TYPE_F32.
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	cols := C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t,
		C.int(s0), C.int(s1),
		C.int(p0), C.int(p1),
		C.int(d0), C.int(d1),
		true, C.GGML_TYPE_F32)
	return &Tensor{b: t.b, t: cols}
}

Michael Yang's avatar
Michael Yang committed
1131
1132
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
	return &Tensor{
1133
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1134
1135
1136
1137
1138
1139
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
	return &Tensor{
1140
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1141
1142
1143
1144
1145
1146
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
1147
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1148
1149
1150
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}
1151

Michael Yang's avatar
Michael Yang committed
1152
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
1153
1154
	return &Tensor{
		b: t.b,
Michael Yang's avatar
Michael Yang committed
1155
		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
Michael Yang's avatar
Michael Yang committed
1156
1157
1158
	}
}

Michael Yang's avatar
Michael Yang committed
1159
1160
1161
1162
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
Michael Yang's avatar
Michael Yang committed
1163
		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
Michael Yang's avatar
Michael Yang committed
1164
	case 1:
Michael Yang's avatar
Michael Yang committed
1165
		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
Michael Yang's avatar
Michael Yang committed
1166
1167
1168
1169
1170
1171
1172
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}

1173
1174
1175
1176
1177
1178
// ScaledDotProductAttention computes attention with t as the query against
// key and value, using the optional mask and the given scale.
//
// Query and key always have their dimensions 1 and 2 swapped first. When the
// backend has flash attention enabled, value is permuted the same way and the
// work is done by ggml_flash_attn_ext at F32 precision. Otherwise the result
// is built from a full-precision key*query matmul, ggml_soft_max_ext (which
// applies the mask and scale) and a matmul with value, then permuted back and
// made contiguous.
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	// A nil mask is passed to ggml as a nil tensor pointer.
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

	if !t.b.flashAttention {
		scores := key.MulmatFullPrec(ctx, query)
		probs := &Tensor{
			b: t.b,
			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, kqMask, C.float(scale), 0),
		}

		attended := value.Mulmat(ctx, probs)
		return attended.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}

	value = value.Permute(ctx, 0, 2, 1, 3)

	fused := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
	C.ggml_flash_attn_ext_set_prec(fused, C.GGML_PREC_F32)
	return &Tensor{b: t.b, t: fused}
}
1199
1200
1201
1202
1203
1204
1205

// Duplicate builds a ggml_dup node for t and returns it as a Tensor on t's
// backend.
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
	dup := C.ggml_dup(ctx.(*Context).ctx, t.t)
	return &Tensor{b: t.b, t: dup}
}
Michael Yang's avatar
llama4  
Michael Yang committed
1206
1207
1208
1209
1210
1211
1212

// TopK builds a ggml_top_k node for t with the given k and returns it as a
// Tensor on t's backend.
func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
	top := C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k))
	return &Tensor{b: t.b, t: top}
}
1213
1214
1215
1216
1217
1218
1219

// Argsort builds a ggml_argsort node for t using GGML_SORT_ORDER_ASC and
// returns it as a Tensor on t's backend.
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
	order := C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC)
	return &Tensor{b: t.b, t: order}
}