"official/legacy/transformer/model_params.py" did not exist on "3fca8afe7eaaaa845b803ab7951ee3f3ed235208"
ggml.go 47 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
package ggml

// #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
// #cgo windows LDFLAGS: -lpthread
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
import "C"

import (
	"cmp"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"maps"
	"os"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"unicode"
	"unsafe"

	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs"
	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
	"github.com/ollama/ollama/ml/nn/rope"

	"golang.org/x/sync/errgroup"
)

var (
	cpus, accels, gpus []C.ggml_backend_dev_t
	backends           map[C.ggml_backend_dev_t]C.ggml_backend_t
)

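// initDevices enumerates the available backend devices once, initializing a
// backend for each and grouping them into cpus, accels, and gpus.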
var initDevices = sync.OnceFunc(func() {
	ggml.OnceLoad()

	backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
	for i := range C.ggml_backend_dev_count() {
		d := C.ggml_backend_dev_get(i)

		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
			if len(cpus) == 0 {
				// only the first cpu device should be used
				cpus = append(cpus, d)
			}
		case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			accels = append(accels, d)
		case C.GGML_BACKEND_DEVICE_TYPE_GPU,
			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
			gpus = append(gpus, d)
		}

		backends[d] = C.ggml_backend_dev_init(d, nil)
	}
})

type layerDevice struct {
	d  C.ggml_backend_dev_t
	bt C.ggml_backend_buffer_type_t
}

type Backend struct {
	// modelPath is the location of the model data
	modelPath string

	meta *fsggml.GGML

	// allocMemory means that memory should be allocated for tensors and not
	// just a dry run
	allocMemory bool

	// tensorLoadTargets maps from the name of the tensor in the file
	// to the name that is used by the model definition
	tensorLoadTargets map[string][]string

	schedMu       sync.Mutex // Only one Compute can run at a time
	sched         C.ggml_backend_sched_t
	schedBackends []C.ggml_backend_t
	schedBufts    []C.ggml_backend_buffer_type_t

	tensors map[string]*C.struct_ggml_tensor

	// input is the backend buffer type used for inputs
	input C.ggml_backend_buffer_type_t

	// output is the backend device used for outputs
	output C.ggml_backend_dev_t

	// layers is the backend used for repeating layers
	layers map[int]layerDevice

	// requiredMemory is the cumulative memory allocations needed by the backend
	requiredMemory *ml.BackendMemory

	// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
	btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory

	flashAttention bool

	// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
	maxGraphNodes int

	// weightBuffers are the GGML contexts and buffers for allocating weights
	weightBuffers map[*C.struct_ggml_context]C.ggml_backend_buffer_t
}

var once sync.Once

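// New creates a ggml backend for the model at modelPath, assigning tensors to
// devices according to params.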
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
	r, err := os.Open(modelPath)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	meta, err := fsggml.Decode(r, -1)
	if err != nil {
		return nil, err
	}

	once.Do(func() {
		slog.Info(
			"",
			"architecture", meta.KV().Architecture(),
			"file_type", meta.KV().FileType(),
			"name", meta.KV().String("general.name"),
			"description", meta.KV().String("general.description"),
			"num_tensors", len(meta.Tensors().Items()),
			"num_key_values", len(meta.KV()),
		)
	})

	initDevices()

	var requiredMemory ml.BackendMemory
	btDeviceMemory := make(map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory)

	type deviceBufferType struct {
		d   C.ggml_backend_dev_t
		bts []C.ggml_backend_buffer_type_t
	}

	blocks := int(meta.KV().BlockCount())

	// create list of buffer types for the cpu
	cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
	for _, d := range append(accels, append(gpus, cpus...)...) {
		switch C.ggml_backend_dev_type(d) {
		case C.GGML_BACKEND_DEVICE_TYPE_CPU,
			C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
			bt := C.ggml_backend_dev_buffer_type(d)
			cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt)

			btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
		}
	}

	requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
	var props C.struct_ggml_backend_dev_props
	C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
	requiredMemory.CPU.ID = C.GoString(props.id)
	requiredMemory.CPU.Library = C.GoString(props.library)
	requiredMemory.CPU.Weights = make([]uint64, blocks+1)
	requiredMemory.CPU.Cache = make([]uint64, blocks+1)

	// create list of buffer types for each gpu
	var gpuDeviceBufferTypes []deviceBufferType
	requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
	for i, d := range gpus {
		bt := C.ggml_backend_dev_buffer_type(d)
		gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
			d:   d,
			bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
		})

		btDeviceMemory[bt] = &requiredMemory.GPUs[i]
		requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
		var props C.struct_ggml_backend_dev_props
		C.ggml_backend_dev_get_props(d, &props)
		requiredMemory.GPUs[i].ID = C.GoString(props.id)
		requiredMemory.GPUs[i].Library = C.GoString(props.library)
		requiredMemory.GPUs[i].Weights = make([]uint64, blocks+1)
		requiredMemory.GPUs[i].Cache = make([]uint64, blocks+1)
	}

	// inputs always use cpu
	input := cpuDeviceBufferType

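	// assignLayer returns the device buffer type for a layer index based on
	// the requested gpu layer assignments, falling back to the cpu.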
	assignLayer := func(layer int) deviceBufferType {
		for _, p := range params.GPULayers {
			for _, l := range p.Layers {
				if l == layer {
					for i := range requiredMemory.GPUs {
						if requiredMemory.GPUs[i].DeviceID == p.DeviceID {
							return gpuDeviceBufferTypes[i]
						}
					}

					return cpuDeviceBufferType
				}
			}
		}

		return cpuDeviceBufferType
	}

	// repeating layers are assigned based on the requested gpu layer assignments
	layers := make([]deviceBufferType, blocks)
	for i := range layers {
		layers[i] = assignLayer(i)
	}

	// the output layer is assigned in the same way, using index blocks
	output := assignLayer(blocks)

	maxTensors := len(meta.Tensors().Items())
	maxTensors += 1
	// each layer has at most 2 extra tensors for rope operations
	maxTensors += blocks * 2

	type tensor struct {
		source *fsggml.Tensor
		target string
	}

	// some tensors are mapped to different names so keep a list
	targets := make(map[string][]string)

	// contexts are shared by tensors of the same buffer type
	ctxs := make(map[C.ggml_backend_buffer_type_t]*C.struct_ggml_context)
	createTensor := func(t tensor, bts []C.ggml_backend_buffer_type_t, layer int) *C.struct_ggml_tensor {
		for _, bt := range bts {
			if _, ok := ctxs[bt]; !ok {
				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
					mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
					no_alloc: true,
				})
			}

			targets[t.source.Name] = append(targets[t.source.Name], t.target)

			name := t.source.Name
			if t.target != "" {
				name = t.target
			}

			cname := C.CString(name)
			defer C.free(unsafe.Pointer(cname))
			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
				return tt
			}

			kind := t.source.Kind
			if t.source.Kind == 4 {
				// transform raw mxfp4 stream to ggml mxfp4 format
				kind = 39
			} else if t.source.Kind == uint32(fsggml.TensorTypeBF16) && strings.HasSuffix(t.source.Name, "_exps.bias") {
				// transform "_exps.bias" from bf16 to fp32; add_ids only supports fp32 tensors
				kind = uint32(fsggml.TensorTypeF32)
			}

			tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
			C.ggml_set_name(tt, cname)

			logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))

			size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
			if layer == -1 {
				requiredMemory.InputWeights += uint64(size)
			} else {
				btDeviceMemory[bt].Weights[layer] += uint64(size)
			}

			//nolint:staticcheck // TODO: check if buffer type supports this tensor
			return tt
		}

		return nil
	}

	contains := func(s string, parts ...string) bool {
		split := strings.Split(s, ".")
		for _, part := range parts {
			if slices.Contains(split, part) {
				return true
			}
		}

		return false
	}

	for _, t := range meta.Tensors().Items() {
		switch {
		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
			createTensor(tensor{source: t}, input.bts, -1)
			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
				createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
			}
		case contains(t.Name, "cls", "output", "output_norm",
			"altup_proj", "altup_unembd_proj",
			"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
			createTensor(tensor{source: t}, output.bts, blocks)
		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
			// TODO: assign vision tensors to the gpu if possible
			createTensor(tensor{source: t}, output.bts, blocks)
		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
			// these tensors should be repeated per layer
			for i, layer := range layers {
				createTensor(tensor{
					source: t,
					target: "blk." + strconv.Itoa(i) + "." + t.Name,
				}, layer.bts, i)
			}
		default:
			layerIndex := -1
			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
				if i, err := strconv.Atoi(fields[0]); err == nil {
					layerIndex = i
				}
			}

			if layerIndex >= 0 {
				createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
			} else {
				// load all other tensors on the cpu
				createTensor(tensor{source: t}, input.bts, -1)
			}
		}
	}

	// map tensor names to tensors for easy lookup later
	tensors := make(map[string]*C.struct_ggml_tensor)
	for _, c := range ctxs {
		for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
			tensors[C.GoString(C.ggml_get_name(t))] = t
		}
	}

	// map devices to backend buffer types so new tensors can be assigned to the correct device
	deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t)

	// create backends and buffer types used for the compute graph scheduler
	var schedBackends []C.ggml_backend_t
	var schedBufts []C.ggml_backend_buffer_type_t
	for _, d := range append(gpus, append(accels, cpus...)...) {
		b := backends[d]
		bt := C.ggml_backend_get_default_buffer_type(b)

		// Always include CPU as a fallback but otherwise, just use the devices where we assigned layers
		if !slices.Contains(cpuDeviceBufferType.bts, bt) {
			if c, ok := ctxs[bt]; !ok || C.ggml_get_first_tensor(c) == nil {
				continue
			}
		}

		deviceBufferTypes[d] = bt

		schedBackends = append(schedBackends, b)
		schedBufts = append(schedBufts, bt)

		if C.ggml_backend_is_cpu(b) {
			// set number of threads for cpu backend
			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
		}
	}

	maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)

	sched := C.ggml_backend_sched_new_ext(
		(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
		(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
		C.int(len(schedBackends)),
		C.size_t(maxGraphNodes),
		C._Bool(false),
		C._Bool(true),
		C._Bool(params.AllocMemory),
	)

	// allocate buffers for each context
	bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs))
	for bt, c := range ctxs {
		if C.ggml_get_first_tensor(c) == nil {
			continue
		}

		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
		if b == nil {
			for _, b := range bbs {
				C.ggml_backend_buffer_free(b)
			}

			for _, ctx := range ctxs {
				C.ggml_free(ctx)
			}

			panic(ml.ErrNoMem{BackendMemory: requiredMemory})
		}

		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
		bbs[c] = b
	}

	for bs := range maps.Values(bbs) {
		logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
			"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
	}

	return &Backend{
		modelPath:         modelPath,
		allocMemory:       params.AllocMemory,
		flashAttention:    params.FlashAttention,
		meta:              meta,
		tensorLoadTargets: targets,
		tensors:           tensors,
		sched:             sched,
		schedBackends:     schedBackends,
		schedBufts:        schedBufts,
		input:             deviceBufferTypes[input.d],
		output:            output.d,
		layers: func() map[int]layerDevice {
			m := make(map[int]layerDevice)
			for i, layer := range layers {
				m[i] = layerDevice{
					d:  layer.d,
					bt: deviceBufferTypes[layer.d],
				}
			}
			return m
		}(),
		requiredMemory: &requiredMemory,
		btDeviceMemory: btDeviceMemory,
		maxGraphNodes:  maxGraphNodes,
		weightBuffers:  bbs,
	}, nil
}

func init() {
	ml.RegisterBackend("ggml", New)
}

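// Close frees the scheduler along with the contexts and buffers holding the
// model weights.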
func (b *Backend) Close() {
	if b == nil {
		return
	}

	for ctx, b := range b.weightBuffers {
		C.ggml_backend_buffer_free(b)
		C.ggml_free(ctx)
	}

	C.ggml_backend_sched_free(b.sched)
}

func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
	if !b.allocMemory {
		return errors.New("cannot load model without memory allocation")
	}

	// Mimic llama runner logs summarizing layers and memory
	gpuLayers := 0
	for layer := range maps.Values(b.layers) {
		switch C.ggml_backend_dev_type(layer.d) {
		case C.GGML_BACKEND_DEVICE_TYPE_GPU,
			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
			gpuLayers++
		}
	}
	slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers))

	switch C.ggml_backend_dev_type(b.output) {
	case C.GGML_BACKEND_DEVICE_TYPE_CPU:
		slog.Info("offloading output layer to CPU")
	case C.GGML_BACKEND_DEVICE_TYPE_GPU,
		C.GGML_BACKEND_DEVICE_TYPE_IGPU:
		slog.Info("offloading output layer to GPU")
		gpuLayers++
	case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
		slog.Info("offloading output layer to ACCEL")
	}
	slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(b.layers)+1))

	var doneBytes atomic.Uint64
	totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset

	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(runtime.GOMAXPROCS(0))
	for _, t := range b.meta.Tensors().Items() {
		g.Go(func() error {
			tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
			for i := range tts {
				target := b.tensorLoadTargets[t.Name][i]
				if target == "" {
					target = t.Name
				}

				tt, ok := b.tensors[target]
				if !ok {
					return fmt.Errorf("unassigned tensor: %s", t.Name)
				}

				tts[i] = tt
			}

			// Create a new FD for each goroutine so that each FD is read sequentially, rather than
			// seeking around within an FD shared between all goroutines.
			file, err := os.Open(b.modelPath)
			if err != nil {
				slog.Warn("file open error", "file", b.modelPath, "error", err)
				return err
			}
			defer file.Close()
			sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))

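			// mxfp4 blocks are 17 bytes: one shared scale byte followed by 16
			// bytes of packed 4-bit elements; the loop below reorders the
			// nibbles into the layout expected by ggml's mxfp4 type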
			if t.Kind == 4 && tts[0]._type == 39 {
				// source is mxfp4, target is ggml mxfp4

				const BS = 17                             // MXFP4 block size
				bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
				var s uint64
				var tmp [16]byte
				for s < t.Size() {
					// Stop if either the parent context has been canceled or if any of the other tensors returned an error
					if err := ctx.Err(); err != nil {
						return err
					}
					n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
					if err != nil {
						slog.Warn("file read error", "file", b.modelPath, "error", err)
						return err
					}
					for j := range n / BS {
						for i := 1; i < 9; i++ {
							// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
							a, b := bts[j*BS+i], bts[j*BS+i+8]
							tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
							tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
						}
						copy(bts[j*BS+1:j*BS+17], tmp[:])
					}

					for _, tt := range tts {
						C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
					}

					s += uint64(n)

					if progress != nil {
						done := doneBytes.Add(uint64(n))
						progress(float32(done) / float32(totalBytes))
					}
				}
				return nil
			} else if strings.HasSuffix(t.Name, "_exps.bias") && t.Kind == 30 && tts[0]._type == 0 {
				// source is bf16, target is ggml fp32

				// data is bf16 but we need to convert to fp32
				bts := make([]byte, 128*format.KibiByte)
				var e uint64
				for e < t.Elements() {
					// Stop if either the parent context has been canceled or if any of the other tensors returned an error
					if err := ctx.Err(); err != nil {
						return err
					}
					n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Elements()-e)*2)])
					if err != nil {
						slog.Warn("file read error", "file", b.modelPath, "error", err)
						return err
					}
					fp32 := ConvertToF32(bts, uint32(fsggml.TensorTypeBF16), uint64(n/2))

					for _, tt := range tts {
						C.ggml_backend_tensor_set(tt, unsafe.Pointer(&fp32[0]), C.size_t(e*4), C.size_t(n*2))
					}
					e += uint64(n / 2)
					if progress != nil {
						done := doneBytes.Add(uint64(n))
						progress(float32(done) / float32(totalBytes))
					}
				}
				return nil
			}

			bts := make([]byte, 128*format.KibiByte)

			var s uint64
			for s < t.Size() {
				// Stop if either the parent context has been canceled or if any of the other tensors returned an error
				if err := ctx.Err(); err != nil {
					return err
				}

				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
				if err != nil {
					slog.Warn("file read error", "file", b.modelPath, "error", err)
					return err
				}

				for _, tt := range tts {
					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
				}

				s += uint64(n)

				if progress != nil {
					done := doneBytes.Add(uint64(n))
					progress(float32(done) / float32(totalBytes))
				}
			}

			return nil
		})
	}

	// Cleanup any backend state from devices that we didn't end up using
nextDevice:
	for _, d := range append(gpus, append(accels, cpus...)...) {
		for _, backend := range b.schedBackends {
			if d == C.ggml_backend_get_device(backend) {
				continue nextDevice
			}
		}

		C.ggml_backend_dev_reset(d)
	}

	if err := g.Wait(); err != nil {
		return err
	}

	return nil
}

func (b *Backend) BackendMemory() ml.BackendMemory {
	return *b.requiredMemory
}

func (b *Backend) Config() fs.Config {
	return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
	if t, ok := b.tensors[name]; ok {
		return &Tensor{b: b, t: t}
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
	return b.NewContextSize(b.maxGraphNodes)
}

func (b *Backend) NewContextSize(n int) ml.Context {
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}

	var allocatedBuffers []C.ggml_backend_buffer_t

	return &Context{
		b:             b,
		maxGraphNodes: n,
		ctx: C.ggml_init(C.struct_ggml_init_params{
			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
			no_alloc: true,
		}),
		allocatedBuffers: &allocatedBuffers,
		layer:            -1,
	}
}

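// CacheConfig reports the cache layout this backend expects, which differs
// depending on whether flash attention is enabled.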
func (b *Backend) CacheConfig() ml.CacheConfig {
	if b.flashAttention {
		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
	} else {
		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
	}
}

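// BackendDevices reports device information for each gpu device, skipping
// devices that are idle for the currently loaded model.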
func (b *Backend) BackendDevices() []ml.DeviceInfo {
	deviceInfos := []ml.DeviceInfo{}
	for _, dev := range gpus {
		// If we have a model loaded, and it's only loaded on a subset of the devices
		// skip idle/unused devices to avoid initializing them and causing VRAM allocations
		if b.allocMemory {
			idleDev := true
			for _, backend := range b.schedBackends {
				if dev == C.ggml_backend_get_device(backend) {
					idleDev = false
					break
				}
			}
			if idleDev {
				slog.Debug("skipping unused backend device", "description", C.GoString(C.ggml_backend_dev_description(dev)))
				continue
			}
		}

		info := ml.DeviceInfo{}
		props := C.struct_ggml_backend_dev_props{}
		C.ggml_backend_dev_get_props(dev, &props)
		info.Name = C.GoString(props.name)
		info.Description = C.GoString(props.description)
		info.ID = C.GoString(props.id)
		info.ComputeMajor = (int)(props.compute_major)
		info.ComputeMinor = (int)(props.compute_minor)
		info.DriverMajor = (int)(props.driver_major)
		info.DriverMinor = (int)(props.driver_minor)
		info.Integrated = props.integrated != 0
		if props.library != nil {
			info.Library = C.GoString(props.library)
		}
		if props.device_id != nil {
			info.PCIID = C.GoString(props.device_id)
		}
		info.LibraryPath = ggml.LibPaths()
		C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total)
		info.TotalMemory = (uint64)(props.memory_total)
		info.FreeMemory = (uint64)(props.memory_free)

		deviceInfos = append(deviceInfos, info)
	}
	return deviceInfos
}

type Context struct {
	b *Backend

	ctx   *C.struct_ggml_context
	graph *C.struct_ggml_cgraph

	// batchSize is a hint to optimize processing
	batchSize int

	// buft is the buffer type used for new tensors
	buft C.ggml_backend_buffer_type_t

	// allocatedBuffers are buffers for tensors that we have allocated in this context
	// so that we can free them when we close the context
	allocatedBuffers *[]C.ggml_backend_buffer_t

	// maxGraphNodes is the maximum allowed number of graph nodes in this context
	maxGraphNodes int

	// layer is the graph layer that this context is allocating for - assumed to be cache
	layer int
}

func (c *Context) Input() ml.Context {
	if c.b.input != nil {
		return &Context{
			b:                c.b,
			ctx:              c.ctx,
			buft:             c.b.input,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
			layer:            -1,
		}
	}

	return c
}

func (c *Context) Layer(i int) ml.Context {
	if layer, ok := c.b.layers[i]; ok {
		return &Context{
			b:                c.b,
			ctx:              c.ctx,
			buft:             layer.bt,
			allocatedBuffers: c.allocatedBuffers,
			maxGraphNodes:    c.maxGraphNodes,
			layer:            i,
		}
	}

	return c
}

func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
	if c.graph == nil {
		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
	}

	for _, tensor := range tensors {
		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
	}

	return c
}

func (c *Context) SetBatchSize(batchSize int) {
	c.batchSize = batchSize
}

func (c *Context) Compute(tensors ...ml.Tensor) {
	c.ComputeWithNotify(nil, tensors...)
}

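// ComputeWithNotify schedules the graph for asynchronous execution, invoking
// cb (if provided) on its own goroutine once the scheduler lock is held.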
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
	c.b.schedMu.Lock()
	defer c.b.schedMu.Unlock()
	if cb != nil {
		go cb()
	}

	if c.batchSize > 0 {
		C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
	}

	if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
		panic(fmt.Errorf("error computing ggml graph: %v", status))
	}
	C.ggml_backend_sched_reset(c.b.sched)

	needSync := true
	sync := func() {
		if needSync {
			C.ggml_backend_sched_synchronize(c.b.sched)
			needSync = false
		}
	}

	for _, t := range tensors {
		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
			t.(*Tensor).sync = sync
		}
	}
}

func (c *Context) Reserve() {
	if c.batchSize > 0 {
		C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
	}

	reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)

	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))

	// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
	for _, bt := range c.b.schedBufts {
		c.b.btDeviceMemory[bt].Graph = 0
	}

	for i := range c.b.schedBackends {
		bufferSize := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])
		c.b.btDeviceMemory[c.b.schedBufts[i]].Graph += uint64(bufferSize)

		logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
			"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferSize)))
	}

	if !reserved {
		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
	}
}

func (c *Context) MaxGraphNodes() int {
	return c.maxGraphNodes
}

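// shapeToGGML converts an int shape to the int64 array layout ggml expects.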
func shapeToGGML(shape []int) *C.int64_t {
	sh := make([]C.int64_t, len(shape))
	for i, s := range shape {
		sh[i] = C.int64_t(s)
	}

	return &sh[0]
}

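// pad rounds length up to the next multiple of pad.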
func pad(length, pad C.size_t) C.size_t {
	return ((length + pad - 1) / pad) * pad
}

func (c *Context) newTensor(dtype ml.DType, shape []int) *Tensor {
	if c.buft == nil {
		panic("set Input or Layer before creating tensors")
	}

	cdtype := ggmlDType(dtype)

	if len(shape) < 1 || shape[0] == 0 {
		var shape C.int64_t = 0
		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
	} else if len(shape) > 4 {
		panic("unsupported number of dimensions")
	}

	for _, dim := range shape {
		if dim < 1 {
			panic("invalid shape")
		}
	}

	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))

	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
	if c.layer >= 0 {
		c.b.btDeviceMemory[c.buft].Cache[c.layer] += uint64(size)
	}

	if b == nil {
		panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
	}

	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
	return &Tensor{b: c.b, t: t}
}

func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	return c.newTensor(dtype, shape)
}

func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	t := c.newTensor(dtype, shape)
	if c.b.allocMemory {
		C.ggml_set_zero(t.t)
	}
	return t
}

func checkShape[S ~[]E, E any](s S, shape ...int) {
	n := len(s)

	if n == 0 {
		return
	}

	for _, v := range shape {
		n /= v
	}

	if n != 1 {
		panic(fmt.Errorf("invalid shape: %v", shape))
	}
}

func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor {
	// Unchecked to handle quantized types
	t := c.newTensor(dtype, shape)
	if c.b.allocMemory {
		t.FromBytes(s)
	}

	return t
}

func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor {
	checkShape(s, shape...)

	t := c.newTensor(ml.DTypeF32, shape)
	if c.b.allocMemory {
		t.FromFloats(s)
	}

	return t
}

func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor {
	checkShape(s, shape...)

	t := c.newTensor(ml.DTypeI32, shape)
	if c.b.allocMemory {
		t.FromInts(s)
	}

	return t
}

func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	switch dtype {
	case ml.DTypeF32:
		// ggml_arange creates a float32 tensor
		return &Tensor{
			b: c.b,
			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
		}
	case ml.DTypeI32:
		// ggml_cast does not support float32 to int32 conversion
		arange := make([]int32, 0, int((stop-start)/step))
		for i := start; i < stop; i += step {
			arange = append(arange, int32(i))
		}

		return c.Input().FromInts(arange, len(arange))
	default:
		panic("unsupported dtype for arange")
	}
}

func (c *Context) Close() {
	if c != nil {
		for _, b := range *c.allocatedBuffers {
			C.ggml_backend_buffer_free(b)
		}
		*c.allocatedBuffers = nil

		C.ggml_free(c.ctx)
	}
}

type Tensor struct {
	b    *Backend
	t    *C.struct_ggml_tensor
	sync func()
}

func (t *Tensor) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", C.GoString(C.ggml_get_name(t.t))),
		slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
		slog.Any("shape", t.Shape()),
	)
}

func (t *Tensor) Dim(n int) int {
	return int(t.t.ne[n])
}

func (t *Tensor) Stride(n int) int {
	return int(t.t.nb[n])
}

func (t *Tensor) Shape() []int {
	shape := make([]int, C.ggml_n_dims(t.t))
	for i := range shape {
		shape[i] = t.Dim(i)
	}

	return shape
}

func (t *Tensor) Bytes() (data []byte) {
	if t.sync != nil {
		data = make([]byte, C.ggml_nbytes(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
}

func (t *Tensor) Floats() (data []float32) {
	if t.sync != nil {
		data = make([]float32, C.ggml_nelements(t.t))

		t.sync()
		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
	}

	return
}

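// tensorSet copies s into t, panicking if the data size does not match the
// tensor size; an empty slice is a no-op.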
func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) {
	if len(s) == 0 {
		return
	}
	if int(C.ggml_nbytes(t.t)) != len(s)*binary.Size(s[0]) {
		panic("data size does not match tensor size")
	}
	C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}

func (t *Tensor) FromBytes(s []byte) {
	tensorSet(t, s)
}

func (t *Tensor) FromFloats(s []float32) {
	tensorSet(t, s)
}

func (t *Tensor) FromInts(s []int32) {
	tensorSet(t, s)
}

func (t *Tensor) DType() ml.DType {
	switch t.t._type {
	case C.GGML_TYPE_F32:
		return ml.DTypeF32
	case C.GGML_TYPE_F16:
		return ml.DTypeF16
	case C.GGML_TYPE_Q8_0:
		return ml.DTypeQ80
	case C.GGML_TYPE_Q4_0:
		return ml.DTypeQ40
	case C.GGML_TYPE_I32:
		return ml.DTypeI32
	case C.GGML_TYPE_MXFP4:
		return ml.DTypeMXFP4
	default:
		return ml.DTypeOther
	}
}

func ggmlDType(dtype ml.DType) uint32 {
	switch dtype {
	case ml.DTypeF32:
		return C.GGML_TYPE_F32
	case ml.DTypeF16:
		return C.GGML_TYPE_F16
	case ml.DTypeQ80:
		return C.GGML_TYPE_Q8_0
	case ml.DTypeQ40:
		return C.GGML_TYPE_Q4_0
	case ml.DTypeI32:
		return C.GGML_TYPE_I32
	case ml.DTypeMXFP4:
		return C.GGML_TYPE_MXFP4
	default:
		panic("unsupported dtype")
	}
}

func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
	}
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

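// Repeat tiles t n times along dimension dim by broadcasting it against a
// template tensor of the target shape.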
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	}

	shape := make([]C.int64_t, C.GGML_MAX_DIMS)
	for i := range C.GGML_MAX_DIMS {
		if i == dim {
			shape[i] = C.int64_t(t.Dim(i) * n)
		} else {
			shape[i] = C.int64_t(t.Dim(i))
		}
	}

	tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
	return &Tensor{
		b: t.b,
		t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
	}
}

func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	if len(s) > 0 {
		return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
	}

	return t
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
	}
}

func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
	if slices.Contains(shape, -1) {
		inferShape(t, shape)
	}

	switch len(shape) {
	case 0:
		return &Tensor{
			b: t.b,
			t: C.ggml_cont(ctx.(*Context).ctx, t.t),
		}
	case 1:
		return &Tensor{
			b: t.b,
			t: C.ggml_cont_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
			b: t.b,
			t: C.ggml_cont_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
			b: t.b,
			t: C.ggml_cont_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
			b: t.b,
			t: C.ggml_cont_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

// Mulmat performs matrix multiplication between two tensors.
// If t has shape [m, p, ...] and t2 has shape [m, n, ...],
// Mulmat returns a new Tensor with shape [p, n, ...].
//
// Note: this is similar to matmul(t2, t.transpose(-1, -2)) in other libraries.
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

	return &Tensor{
		b: t.b,
		t: mul,
	}
}

func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
	}
}

func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_add_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
	}
}

func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
	}
}

func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
		if b != nil {
			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
		}
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
	} else if shape[3] != 0 {
		panic("cuda does not support 4d tensors")
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
	}
}

// Permute permutes t according to order. Permute panics if the number of dimensions
// in order does not match the number of dimensions in t.
func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(order) != len(t.Shape()) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute requires 4 dimensions so fill in the rest
	for i := len(order); i < 4; i++ {
		order = append(order, i)
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(order[0]), C.int(order[1]), C.int(order[2]), C.int(order[3])),
	}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}

// inferShape updates shape in place to automatically set a single -1 dimension
// based on the input tensor and the other dimensions
func inferShape(t *Tensor, shape []int) {
	total := 1
	for _, dim := range t.Shape() {
		total *= dim
	}

	dim := -1
	for i := range shape {
		switch shape[i] {
		case -1:
			if dim != -1 {
				panic("only one dimension can be inferred")
			}
			dim = i
		case 0:
			panic("dimension cannot be zero")
		default:
			if total%shape[i] != 0 {
				panic("cannot infer dimension")
			}

			total /= shape[i]
		}
	}

	if dim != -1 {
		shape[dim] = total
	}
}

func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	if !C.ggml_is_contiguous(t.t) {
		return t.Contiguous(ctx, shape...)
	}

	if slices.Contains(shape, -1) {
		inferShape(t, shape)
	}

	switch len(shape) {
	case 1:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
		}
	case 2:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
		}
	case 3:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
		}
	case 4:
		return &Tensor{
			b: t.b,
			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
	}
}

func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sin(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_cos(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
	}
}

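// View creates a (possibly strided) view of t. Arguments alternate between
// dimension sizes and byte strides: dim0[, stride1, dim1[, stride2, dim2[,
// stride3, dim3]]].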
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
		}
	case 3:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]),
				C.size_t(shape[1]),
				C.size_t(offset)),
		}
	case 5:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
				C.size_t(shape[1]), C.size_t(shape[3]),
				C.size_t(offset)),
		}
	case 7:
		return &Tensor{
			b: t.b,
			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
				C.size_t(offset)),
		}
	default:
		panic("unsupported number of dimensions")
	}
}

func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
	// Default options
	opts := rope.Options{Factors: &Tensor{}}

	// Apply any provided options
	for _, option := range options {
		option(&opts)
	}

	dequant := t.t
	if C.ggml_is_quantized(t.t._type) {
		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
	}

	var tt *C.struct_ggml_tensor
	if len(opts.MRoPE.Sections) > 0 {
		mropeSections := make([]C.int32_t, 4)
		for i, section := range opts.MRoPE.Sections {
			mropeSections[i] = C.int32_t(section)
		}

		tt = C.ggml_rope_multi(
			ctx.(*Context).ctx,
			dequant,
			positions.(*Tensor).t,
			opts.Factors.(*Tensor).t,
			C.int(ropeDim),
			unsafe.SliceData(mropeSections),
			C.int(opts.Type),
			cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
			C.float(ropeBase), C.float(ropeScale),
			C.float(opts.YaRN.ExtrapolationFactor),
			cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
			cmp.Or(C.float(opts.YaRN.BetaFast), 32),
			cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
		)
	} else {
		tt = C.ggml_rope_ext(
			ctx.(*Context).ctx,
			dequant,
			positions.(*Tensor).t,
			opts.Factors.(*Tensor).t,
			C.int(ropeDim), C.int(opts.Type),
			cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10),
			C.float(ropeBase), C.float(ropeScale),
			C.float(opts.YaRN.ExtrapolationFactor),
			cmp.Or(C.float(opts.YaRN.AttentionFactor), 1),
			cmp.Or(C.float(opts.YaRN.BetaFast), 32),
			cmp.Or(C.float(opts.YaRN.BetaSlow), 1),
		)
	}
	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
	}
}

func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
	if len(t2) > 0 {
		return &Tensor{
			b: t.b,
			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
		}
	}
	return &Tensor{
		b: t.b,
		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
	}
}

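// QuickGELU applies the faster sigmoid-based GELU approximation in place, or
// the gated GEGLU-quick variant when a second tensor is supplied.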
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
	var tt *C.struct_ggml_tensor
	if len(t2) > 0 {
		tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
	} else {
		tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
	}
	return &Tensor{b: t.b, t: tt}
}

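// SILU applies the SiLU (swish) activation in place, or the gated SwiGLU
// variant when a second tensor is supplied.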
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
	if len(t2) > 0 {
		return &Tensor{
			b: t.b,
			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
		}
	}
	return &Tensor{
		b: t.b,
		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
	}
}

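// RELU applies the ReLU activation in place, or the gated ReGLU variant when
// a second tensor is supplied.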
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
	if len(t2) > 0 {
		return &Tensor{
			b: t.b,
			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
		}
	}
	return &Tensor{
		b: t.b,
		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
	}
}

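// SILUAlphaLimit computes the clamped SwiGLU variant (ggml_swiglu_oai),
// gating t against up with the given alpha and limit.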
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
	}
}

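// Conv2D convolves t2 with kernel t using strides s0/s1, padding p0/p1, and
// dilations d0/d1.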
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
	}
}

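// Conv3D convolves t2 with kernel t across c channels using the per-axis
// strides, padding, and dilations, then reshapes the result using the kernel
// and input batch dimensions divided by the channel count.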
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
	var tt ml.Tensor = &Tensor{
		b: t.b,
		t: C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)),
	}

	tt = tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c)
	return tt
}

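// AvgPool2D applies average pooling with a square k x k kernel, stride s, and
// padding p.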
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
	}
}

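// ScaledDotProductAttention computes attention with t as the query:
// softmax(k^T q * scale + mask) applied to value, with optional attention
// sinks folded into the softmax. vmla, if non-nil, is an extra projection
// multiplied into the attention output (e.g. for multi-head latent
// attention). When flash attention is enabled, the fused ggml_flash_attn_ext
// kernel is used and forced to F32 precision.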
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks, vmla ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
		kqMask = mask.(*Tensor).t
	}

	query := t.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

	if t.b.flashAttention {
		value = value.Permute(ctx, 0, 2, 1, 3)

		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
		if sinks != nil {
			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
		}
		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)

		if vmla != nil {
			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
			cur = cur.Permute(ctx, 0, 2, 1, 3)
			cur = vmla.Mulmat(ctx, cur)
			cur = cur.Permute(ctx, 0, 2, 1, 3)
			cur = cur.Contiguous(ctx)
			kqv = cur.(*Tensor).t
		}

		return &Tensor{b: t.b, t: kqv}
	} else {
		kq := key.MulmatFullPrec(ctx, query)
		kq = &Tensor{
			b: t.b,
			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
		}
		if sinks != nil {
			C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
		}

		kqv := value.Mulmat(ctx, kq)
		if vmla != nil {
			kqv = vmla.Mulmat(ctx, kqv)
		}

		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}

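// Duplicate returns a copy of the tensor backed by new data.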
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
	}
}

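// TopK selects the k largest values along the first dimension via ggml_top_k,
// which yields their indices in descending value order.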
func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
	}
}

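// Argsort returns the indices that sort the first dimension in ascending order.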
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
	}
}

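// Mean reduces the first dimension to its arithmetic mean.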
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_mean(ctx.(*Context).ctx, t.t),
	}
}

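// Variance computes the population variance over the first dimension,
// sum((x - mean)^2) / n, composed from the mean, square, sum, and scale ops.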
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
	return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
		Sqr(ctx).
		SumRows(ctx).
		Scale(ctx, 1/float64(t.Dim(0)))
}

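// Stddev computes the population standard deviation over the first dimension,
// the square root of Variance.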
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
	return t.Variance(ctx).Sqrt(ctx)
}

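// Sqr squares each element of the tensor.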
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
	}
}

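// Sqrt computes the element-wise square root.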
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
	}
}

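// Interpolate resizes the tensor to the target dims using the given sampling
// mode; only nearest and bilinear are supported, any other mode panics.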
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
	var mode C.uint32_t
	switch samplingMode {
	case ml.SamplingModeNearest:
		mode = C.GGML_SCALE_MODE_NEAREST
	case ml.SamplingModeBilinear:
		mode = C.GGML_SCALE_MODE_BILINEAR
	default:
		panic("unsupported interpolate mode")
	}

	return &Tensor{
		b: t.b,
		t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
	}
}

// Slice returns a view of the tensor sliced along dim from low (inclusive) to
// high (exclusive) with the given step.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim is 0 and step > 1, the result is a contiguous copy rather than a view
// so that the strided elements form a properly shaped tensor.
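// For example (hypothetical shapes), t.Slice(ctx, 0, 0, 8, 2) over an [8, n]
// tensor keeps every other element of dim 0, producing a [4, n] copy.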
func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
		panic("invalid slice parameters")
	}

	if dim == 0 && step > 1 {
		// dim=0,step>1 is a special case so handle it here first
		return t.View(ctx,
			low*t.Stride(0), 1,
			step*t.Stride(0), (high-low+1)/step,
			t.Stride(1), t.Dim(1),
			// preserve dim 3 by merging it into dim 2
			t.Stride(2), t.Dim(2)*t.Dim(3),
		).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
	}

	args := []int{
		low * t.Stride(dim), t.Dim(0),
		t.Stride(1), t.Dim(1),
		t.Stride(2), t.Dim(2),
		t.Stride(3), t.Dim(3),
	}

	if step == 1 {
		args[dim*2+1] = high - low
		return t.View(ctx, args[0], args[1:]...)
	} else {
		args[dim*2] = step * t.Stride(dim)
		args[dim*2+1] = (high - low + 1) / step
		return t.View(ctx, args[0], args[1:]...)
	}
}

// Chunk splits the tensor into chunk-sized tensors along dim; the final chunk
// may be smaller if the dimension is not evenly divisible. Each sub-tensor is
// a view of the original.
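// For example (hypothetical shapes), chunking dim 0 of a [10, n] tensor with
// chunk=4 yields views of sizes 4, 4, and 2.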
func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
	sections := make([]int, 0, t.Dim(dim)/chunk+1)
	for rest := t.Dim(dim); rest > 0; rest -= chunk {
		sections = append(sections, min(chunk, rest))
	}
	return t.ChunkSections(ctx, dim, sections...)
}

// ChunkSections splits the tensor into section-sized tensors along dim. Each
// sub-tensor is a view of the original. ChunkSections panics if the size of
// dim does not equal the sum of sections.
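// For example (hypothetical shapes), t.ChunkSections(ctx, 0, 2, 3, 5) splits
// a [10, n] tensor into views of sizes 2, 3, and 5 along dim 0.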
func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
	var offset int
	s := make([]ml.Tensor, len(sections))
	for i, section := range sections {
		s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
		offset += section
	}
	if offset != t.Dim(dim) {
		panic("sections do not sum to tensor dimension")
	}
	return s
}