ggml.go 22.3 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
package ggml
2
3

import (
Michael Yang's avatar
Michael Yang committed
4
	"cmp"
5
6
	"encoding/binary"
	"errors"
Michael Yang's avatar
Michael Yang committed
7
	"fmt"
8
	"io"
Michael Yang's avatar
Michael Yang committed
9
	"log/slog"
10
	"math"
11
	"slices"
Michael Yang's avatar
Michael Yang committed
12
	"strings"
13

14
	"github.com/ollama/ollama/format"
Michael Yang's avatar
Michael Yang committed
15
	"github.com/ollama/ollama/fs/util/bufioutil"
16
17
)

Michael Yang's avatar
Michael Yang committed
18
19
20
type GGML struct {
	container
	model
21
	Length int64
Michael Yang's avatar
Michael Yang committed
22
}
23

Michael Yang's avatar
Michael Yang committed
24
type model interface {
Michael Yang's avatar
Michael Yang committed
25
	KV() KV
Michael Yang's avatar
Michael Yang committed
26
	Tensors() Tensors
27
28
}

29
30
// KV holds model metadata values keyed by their full metadata name.
type KV map[string]any

Michael Yang's avatar
Michael Yang committed
31
func (kv KV) Architecture() string {
Michael Yang's avatar
Michael Yang committed
32
	return kv.String("general.architecture", "unknown")
Michael Yang's avatar
Michael Yang committed
33
34
}

35
func (kv KV) Kind() string {
Michael Yang's avatar
Michael Yang committed
36
	return kv.String("general.type", "unknown")
37
38
}

Michael Yang's avatar
Michael Yang committed
39
func (kv KV) ParameterCount() uint64 {
40
41
	val, _ := keyValue(kv, "general.parameter_count", uint64(0))
	return val
Michael Yang's avatar
Michael Yang committed
42
43
}

44
func (kv KV) FileType() FileType {
Michael Yang's avatar
Michael Yang committed
45
	if t := kv.Uint("general.file_type"); t > 0 {
46
		return FileType(t)
Michael Yang's avatar
Michael Yang committed
47
48
	}

49
	return FileTypeUnknown
Michael Yang's avatar
Michael Yang committed
50
51
52
}

func (kv KV) BlockCount() uint64 {
Michael Yang's avatar
Michael Yang committed
53
54
55
56
57
	return uint64(kv.Uint("block_count"))
}

func (kv KV) EmbeddingLength() uint64 {
	return uint64(kv.Uint("embedding_length"))
Michael Yang's avatar
Michael Yang committed
58
59
}

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// HeadCount returns the attention head count for each layer, with one entry
// per block (see BlockCount).
//
// "attention.head_count" may be a scalar or a per-layer array:
//   - A single value is used for every layer.
//   - An array shorter than the layer count is padded with 1.
//   - Elements beyond the layer count are dropped after a warning.
func (kv KV) HeadCount() []uint64 {
	values := kv.UintOrArrayValueAsArray("attention.head_count", 1)

	fill := uint32(1)
	if len(values) == 1 {
		fill = values[0]
	}

	layerCount := int(kv.BlockCount())
	if len(values) > layerCount {
		slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(values), "layers", layerCount)
	}

	counts := make([]uint64, layerCount)
	for layer := range counts {
		v := fill
		if layer < len(values) {
			v = values[layer]
		}
		counts[layer] = uint64(v)
	}
	return counts
}

81
82
func (kv KV) HeadCountMax() uint64 {
	return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
Michael Yang's avatar
Michael Yang committed
83
84
}

85
86
func (kv KV) HeadCountMin() uint64 {
	return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
Michael Yang's avatar
Michael Yang committed
87
88
}

89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// HeadCountKV returns the number of key/value attention heads for each layer,
// with one entry per block (see BlockCount).
//
// "attention.head_count_kv" may be a scalar or a per-layer array:
//   - A single value is used for every layer.
//   - An array shorter than the layer count is padded with 1.
//   - Elements beyond the layer count are dropped after a warning.
func (kv KV) HeadCountKV() []uint64 {
	headCountKVDefault := uint32(1)
	headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
	if len(headCountKV) == 1 {
		headCountKVDefault = headCountKV[0]
	}

	nLayers := int(kv.BlockCount())
	if len(headCountKV) > nLayers {
		// Name the key that was actually read so the warning points at the
		// right metadata entry.
		slog.Warn("got more elements of attention.head_count_kv than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
	}

	out := make([]uint64, nLayers)
	for i := range nLayers {
		if i < len(headCountKV) {
			out[i] = uint64(headCountKV[i])
		} else {
			out[i] = uint64(headCountKVDefault)
		}
	}
	return out
}

110
111
112
113
114
115
116
117
118
119
// HeadCountKVMax returns the largest "attention.head_count_kv" value across
// layers (scalar or array), defaulting to 1.
func (kv KV) HeadCountKVMax() uint64 {
	largest := kv.UintOrMaxArrayValue("attention.head_count_kv", 1)
	return uint64(largest)
}

// HeadCountKVMin returns the smallest "attention.head_count_kv" value across
// layers (scalar or array), defaulting to 1.
func (kv KV) HeadCountKVMin() uint64 {
	smallest := kv.UintOrMinArrayValue("attention.head_count_kv", 1)
	return uint64(smallest)
}

func (kv KV) EmbeddingHeadCountMax() uint64 {
	if heads := kv.HeadCountMin(); heads > 0 {
Michael Yang's avatar
Michael Yang committed
120
		return kv.EmbeddingLength() / heads
Michael Yang's avatar
Michael Yang committed
121
122
123
124
125
126
	}

	return 0
}

func (kv KV) EmbeddingHeadCountK() uint64 {
127
	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
Michael Yang's avatar
Michael Yang committed
128
129
130
}

func (kv KV) EmbeddingHeadCountV() uint64 {
131
	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
Michael Yang's avatar
Michael Yang committed
132
133
134
}

func (kv KV) ContextLength() uint64 {
Michael Yang's avatar
Michael Yang committed
135
	return uint64(kv.Uint("context_length"))
Michael Yang's avatar
Michael Yang committed
136
137
}

Michael Yang's avatar
Michael Yang committed
138
func (kv KV) ChatTemplate() string {
Michael Yang's avatar
Michael Yang committed
139
140
141
	return kv.String("tokenizer.chat_template")
}

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// ssm architecture parameters. Each accessor reads an architecture-prefixed
// "ssm.*" key and returns 0 when the key is absent.

// SSMConvKernel returns "<arch>.ssm.conv_kernel".
func (kv KV) SSMConvKernel() uint64 {
	v := kv.Uint("ssm.conv_kernel")
	return uint64(v)
}

// SSMInnerSize returns "<arch>.ssm.inner_size".
func (kv KV) SSMInnerSize() uint64 {
	v := kv.Uint("ssm.inner_size")
	return uint64(v)
}

// SSMStateSize returns "<arch>.ssm.state_size".
func (kv KV) SSMStateSize() uint64 {
	v := kv.Uint("ssm.state_size")
	return uint64(v)
}

// SSMGroupCount returns "<arch>.ssm.group_count".
func (kv KV) SSMGroupCount() uint64 {
	v := kv.Uint("ssm.group_count")
	return uint64(v)
}

// general types

Michael Yang's avatar
Michael Yang committed
162
func (kv KV) String(key string, defaultValue ...string) string {
163
164
	val, _ := keyValue(kv, key, append(defaultValue, "")...)
	return val
Michael Yang's avatar
Michael Yang committed
165
166
167
}

func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
168
169
	val, _ := keyValue(kv, key, append(defaultValue, 0)...)
	return val
Michael Yang's avatar
Michael Yang committed
170
171
172
}

func (kv KV) Float(key string, defaultValue ...float32) float32 {
173
174
	val, _ := keyValue(kv, key, append(defaultValue, 0)...)
	return val
Michael Yang's avatar
Michael Yang committed
175
176
}

177
// Bool returns the bool stored under key (architecture-prefixed unless it
// starts with "tokenizer." or "general."). When the key is missing or is not a
// bool, it returns the first defaultValue, or false if none is given.
func (kv KV) Bool(key string, defaultValue ...bool) bool {
	fallbacks := append(defaultValue, false)
	b, _ := keyValue(kv, key, fallbacks...)
	return b
}

// UintOrMaxArrayValue returns the largest value stored under key, which may be
// a scalar or an integer array.
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
	_, hi := kv.UintOrArrayValue(key, defaultValue)
	return hi
}

// UintOrMinArrayValue returns the smallest value stored under key, which may
// be a scalar or an integer array.
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
	lo, _ := kv.UintOrArrayValue(key, defaultValue)
	return lo
}

// UintOrArrayValue returns the minimum and maximum of the value stored under
// key, which may be a uint32 scalar or a uint32/int32 array. A missing key
// yields defaultValue for both.
//
// An array stored with no elements also yields defaultValue for both.
// slices.Min and slices.Max panic on an empty slice, so an empty array would
// otherwise crash the caller.
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
	arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
	if len(arrVal) == 0 {
		return defaultValue, defaultValue
	}
	return slices.Min(arrVal), slices.Max(arrVal)
}

// UintOrArrayValueAsArray returns the value stored under key as a slice. The
// stored representations are tried in this order:
//   - A uint32 scalar is returned as a one-element slice.
//   - A uint32 array is returned as-is.
//   - An int32 array is converted element-wise; negative values are logged and
//     still converted.
//
// If none of these is present, the result is []uint32{defaultValue}.
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
	if scalar, ok := keyValue(kv, key, uint32(0)); ok {
		return []uint32{scalar}
	}

	if unsigned, ok := keyValue(kv, key, &array[uint32]{}); ok {
		return unsigned.values
	}

	signed, ok := keyValue(kv, key, &array[int32]{})
	if !ok {
		return []uint32{defaultValue}
	}

	converted := make([]uint32, 0, len(signed.values))
	for i, v := range signed.values {
		if v < 0 {
			slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
		}
		converted = append(converted, uint32(v))
	}
	return converted
}

Michael Yang's avatar
Michael Yang committed
216
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
217
218
	val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
	return val.values
Michael Yang's avatar
Michael Yang committed
219
220
}

Michael Yang's avatar
Michael Yang committed
221
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
222
223
	val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
	return val.values
Michael Yang's avatar
Michael Yang committed
224
225
}

Michael Yang's avatar
Michael Yang committed
226
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
227
228
	val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
	return val.values
Michael Yang's avatar
Michael Yang committed
229
230
}

Patrick Devine's avatar
Patrick Devine committed
231
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
232
233
	val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
	return val.values
Patrick Devine's avatar
Patrick Devine committed
234
235
}

Michael Yang's avatar
Michael Yang committed
236
237
238
239
240
// Bools returns the bool array stored under key. When it is missing or has
// another type, the first defaultValue (or nil) is returned.
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
	var fallback []bool
	if len(defaultValue) > 0 {
		fallback = defaultValue[0]
	}
	arr, _ := keyValue(kv, key, &array[bool]{values: fallback})
	return arr.values
}

241
func (kv KV) OllamaEngineRequired() bool {
242
243
	return slices.Contains([]string{
		"gemma3",
244
		"gemma3n",
245
		"gptoss", "gpt-oss",
Michael Yang's avatar
llama4  
Michael Yang committed
246
		"llama4",
247
		"mistral3",
248
		"mllama",
249
		"qwen25vl",
250
251
		"qwen3", "qwen3moe",
		"qwen3vl", "qwen3vlmoe",
Michael Yang's avatar
Michael Yang committed
252
		"deepseekocr",
253
	}, kv.Architecture())
254
255
}

Michael Yang's avatar
Michael Yang committed
256
// valueTypes lists the scalar metadata value types that keyValue can
// retrieve.
type valueTypes interface {
	bool | string |
		int8 | int16 | int32 | int64 |
		uint8 | uint16 | uint32 | uint64 |
		float32 | float64
}

type arrayValueTypes interface {
	*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
		*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
		*array[string] | *array[float32] | *array[float64] | *array[bool]
Michael Yang's avatar
Michael Yang committed
266
267
}

268
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
Michael Yang's avatar
Michael Yang committed
269
270
271
272
	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
		key = kv.Architecture() + "." + key
	}

273
274
	if val, ok := kv[key].(T); ok {
		return val, true
Michael Yang's avatar
Michael Yang committed
275
276
	}

277
	slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
278
	return defaultValue[0], false
Michael Yang's avatar
Michael Yang committed
279
280
}

281
type Tensors struct {
Michael Yang's avatar
Michael Yang committed
282
	items  []*Tensor
283
	Offset uint64
Michael Yang's avatar
Michael Yang committed
284
}
Michael Yang's avatar
Michael Yang committed
285

Michael Yang's avatar
Michael Yang committed
286
287
288
289
func (s Tensors) Items(prefix ...string) []*Tensor {
	if len(prefix) == 0 {
		return s.items
	}
290

Michael Yang's avatar
Michael Yang committed
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
	var items []*Tensor
	for _, t := range s.items {
		if strings.HasPrefix(t.Name, prefix[0]) {
			items = append(items, t)
		}
	}

	return items
}

func (ts Tensors) GroupLayers() map[string]Layer {
	layers := make(map[string]Layer)
	for _, t := range ts.items {
		parts := strings.Split(t.Name, ".")
		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
			if len(parts) > index+2 {
				// blk and mm should have a number after them, join it
				parts = append(
					[]string{strings.Join(parts[:index+2], ".")},
					parts[index+2:]...)
311
			}
Michael Yang's avatar
Michael Yang committed
312
		}
313

Michael Yang's avatar
Michael Yang committed
314
315
		if _, ok := layers[parts[0]]; !ok {
			layers[parts[0]] = make(Layer)
Michael Yang's avatar
Michael Yang committed
316
317
		}

Michael Yang's avatar
Michael Yang committed
318
319
320
321
		layers[parts[0]][strings.Join(parts[1:], ".")] = t
	}

	return layers
Michael Yang's avatar
Michael Yang committed
322
323
324
325
}

// Layer maps a tensor's name within a layer (as produced by GroupLayers) to
// the tensor.
type Layer map[string]*Tensor

Michael Yang's avatar
Michael Yang committed
326
func (l Layer) Size() (size uint64) {
Michael Yang's avatar
Michael Yang committed
327
	for _, t := range l {
Michael Yang's avatar
Michael Yang committed
328
		size += t.Size()
Michael Yang's avatar
Michael Yang committed
329
330
331
332
333
	}

	return size
}

334
// Tensor describes one tensor in a model file.
type Tensor struct {
	// Name is the tensor's full name, e.g. "blk.0.attn_q.weight".
	Name string `json:"name"`
	// Kind is the numeric tensor type (see TensorType).
	Kind uint32 `json:"kind"`
	// Offset is excluded from JSON. Presumably it is the data offset of this
	// tensor relative to Tensors.Offset; confirm against the decoder.
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	// WriterTo, when set, writes the tensor's data; it is excluded from JSON.
	io.WriterTo `json:"-"`
}

345
346
// block returns the layer index N parsed from a name of the form "blk.N.…".
// Names that do not match that pattern yield math.MaxInt.
func (t Tensor) block() (n int) {
	var index int
	if _, err := fmt.Sscanf(t.Name, "blk.%d.", &index); err != nil {
		return math.MaxInt
	}
	return index
}

353
func (t Tensor) blockSize() uint64 {
Michael Yang's avatar
Michael Yang committed
354
	return TensorType(t.Kind).BlockSize()
355
356
357
358
}

func (t TensorType) BlockSize() uint64 {
	switch t {
Michael Yang's avatar
Michael Yang committed
359
	case
360
361
362
363
364
365
366
367
		TensorTypeF32,
		TensorTypeF16,
		TensorTypeI8,
		TensorTypeI16,
		TensorTypeI32,
		TensorTypeI64,
		TensorTypeF64,
		TensorTypeBF16:
368
		return 1
Michael Yang's avatar
Michael Yang committed
369
	case
370
371
372
373
374
375
376
377
		TensorTypeQ4_0,
		TensorTypeQ4_1,
		TensorTypeQ5_0,
		TensorTypeQ5_1,
		TensorTypeQ8_0,
		TensorTypeQ8_1,
		tensorTypeIQ4_NL,
		4, TensorTypeMXFP4:
378
		return 32
Michael Yang's avatar
Michael Yang committed
379
	default:
380
381
382
383
384
		return 256
	}
}

// typeSize returns the byte size of one block of the tensor's type.
func (t Tensor) typeSize() uint64 {
	kind := TensorType(t.Kind)
	return kind.TypeSize()
}

func (t TensorType) TypeSize() uint64 {
	blockSize := t.BlockSize()
390

391
392
	switch t {
	case TensorTypeF32:
393
		return 4
394
	case TensorTypeF16:
395
		return 2
396
	case TensorTypeQ4_0:
397
		return 2 + blockSize/2
398
	case TensorTypeQ4_1:
399
		return 2 + 2 + blockSize/2
400
	case TensorTypeQ5_0:
401
		return 2 + 4 + blockSize/2
402
	case TensorTypeQ5_1:
403
		return 2 + 2 + 4 + blockSize/2
404
	case TensorTypeQ8_0:
405
		return 2 + blockSize
406
	case TensorTypeQ8_1:
Michael Yang's avatar
Michael Yang committed
407
		return 2 + 2 + blockSize
408
	case TensorTypeQ2_K:
409
		return blockSize/16 + blockSize/4 + 2 + 2
410
	case TensorTypeQ3_K:
411
		return blockSize/8 + blockSize/4 + 12 + 2
412
	case TensorTypeQ4_K:
413
		return 2 + 2 + 12 + blockSize/2
414
	case TensorTypeQ5_K:
415
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
416
	case TensorTypeQ6_K:
417
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
418
	case TensorTypeQ8_K:
Michael Yang's avatar
Michael Yang committed
419
		return 4 + blockSize + 2*blockSize/16
420
	case tensorTypeIQ2_XXS:
421
		return 2 + 2*blockSize/8
422
	case tensorTypeIQ2_XS:
423
		return 2 + 2*blockSize/8 + blockSize/32
424
	case tensorTypeIQ3_XXS:
425
		return 2 + blockSize/4 + blockSize/8
426
	case tensorTypeIQ1_S:
427
		return 2 + blockSize/8 + blockSize/16
428
	case tensorTypeIQ4_NL:
429
		return 2 + blockSize/2
430
	case tensorTypeIQ3_S:
431
		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
432
	case tensorTypeIQ2_S:
433
		return 2 + blockSize/4 + blockSize/16
434
	case tensorTypeIQ4_XS:
435
		return 2 + 2 + blockSize/2 + blockSize/64
436
	case TensorTypeI8:
437
		return 1
438
	case TensorTypeI16:
439
		return 2
440
	case TensorTypeI32:
441
		return 4
442
	case TensorTypeI64:
443
		return 8
444
	case TensorTypeF64:
445
		return 8
446
	case tensorTypeIQ1_M:
447
		return blockSize/8 + blockSize/16 + blockSize/32
448
	case TensorTypeBF16:
Michael Yang's avatar
Michael Yang committed
449
		return 2
450
451
	case 4, TensorTypeMXFP4:
		return 1 + blockSize/2
452
453
454
455
456
	default:
		return 0
	}
}

457
// Elements returns the total element count: the product of all dimensions in
// Shape. An empty shape yields 1.
func (t Tensor) Elements() uint64 {
	product := uint64(1)
	for i := range t.Shape {
		product *= t.Shape[i]
	}
	return product
}

Michael Yang's avatar
Michael Yang committed
465
// Size returns the tensor's data size in bytes: element count times bytes per
// block, divided by elements per block. The multiplication is done first so
// that no partial-block truncation happens before scaling.
func (t Tensor) Size() uint64 {
	scaled := t.Elements() * t.typeSize()
	return scaled / t.blockSize()
}

469
// Type returns the string name of the tensor's type.
func (t Tensor) Type() string {
	kind := TensorType(t.Kind)
	return kind.String()
}

473
474
type container interface {
	Name() string
Michael Yang's avatar
Michael Yang committed
475
	Decode(io.ReadSeeker) (model, error)
476
477
478
}

// File magic numbers. Each one is the first four bytes of a model file,
// interpreted as a little-endian uint32.
const (
	// FILE_MAGIC_GGML marks unversioned `ggml` files.
	FILE_MAGIC_GGML = 0x67676d6c
	// FILE_MAGIC_GGMF marks versioned `ggmf` files.
	FILE_MAGIC_GGMF = 0x67676d66
	// FILE_MAGIC_GGJT marks versioned `ggjt` files.
	FILE_MAGIC_GGJT = 0x67676a74
	// FILE_MAGIC_GGLA marks `ggla` LoRA adapter files.
	FILE_MAGIC_GGLA = 0x67676c61
	// FILE_MAGIC_GGUF_LE is the bytes "GGUF" read as a little-endian word.
	FILE_MAGIC_GGUF_LE = 0x46554747
	// FILE_MAGIC_GGUF_BE is the byte-swapped GGUF magic. Decode treats files
	// that start with it as big-endian.
	FILE_MAGIC_GGUF_BE = 0x47475546
)

Bruce MacDonald's avatar
Bruce MacDonald committed
492
493
// ErrUnsupportedFormat is a sentinel error reporting that a model file's
// format is not supported. Compare with errors.Is.
var ErrUnsupportedFormat = errors.New("unsupported model format")

Michael Yang's avatar
Michael Yang committed
494
func DetectContentType(b []byte) string {
Michael Yang's avatar
Michael Yang committed
495
496
497
498
499
500
501
502
503
	switch binary.LittleEndian.Uint32(b[:4]) {
	case FILE_MAGIC_GGML:
		return "ggml"
	case FILE_MAGIC_GGMF:
		return "ggmf"
	case FILE_MAGIC_GGJT:
		return "ggjt"
	case FILE_MAGIC_GGLA:
		return "ggla"
504
	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
Michael Yang's avatar
Michael Yang committed
505
506
507
508
509
510
		return "gguf"
	default:
		return ""
	}
}

Michael Yang's avatar
Michael Yang committed
511
// Decode decodes a GGML model from the given reader.
512
513
//
// It collects array values for arrays with a size less than or equal to
Michael Yang's avatar
Michael Yang committed
514
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
515
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
516
517
	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)

518
	var magic uint32
Michael Yang's avatar
Michael Yang committed
519
	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
520
		return nil, err
521
522
523
	}

	var c container
524
525
	switch magic {
	case FILE_MAGIC_GGUF_LE:
526
		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
527
	case FILE_MAGIC_GGUF_BE:
528
		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
529
	default:
530
		return nil, errors.New("invalid file magic")
531
532
	}

Michael Yang's avatar
Michael Yang committed
533
	model, err := c.Decode(rs)
534
	if err != nil {
535
		return nil, err
536
537
	}

Michael Yang's avatar
Michael Yang committed
538
539
	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
540
		return nil, err
Michael Yang's avatar
Michael Yang committed
541
542
	}

543
	// final model type
544
545
546
	return &GGML{
		container: c,
		model:     model,
547
548
		Length:    offset,
	}, nil
549
}
Michael Yang's avatar
Michael Yang committed
550

551
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
Jesse Gross's avatar
Jesse Gross committed
552
553
	context *= uint64(numParallel)

Michael Yang's avatar
Michael Yang committed
554
	embedding := f.KV().EmbeddingLength()
555
	heads := f.KV().HeadCountMax()
556
	headsArr := f.KV().HeadCount()
557
	headsKV := f.KV().HeadCountKVMax()
558
	headsKVArr := f.KV().HeadCountKV()
Michael Yang's avatar
Michael Yang committed
559
	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
Michael Yang's avatar
Michael Yang committed
560

561
	embeddingHeads := f.KV().EmbeddingHeadCountMax()
Michael Yang's avatar
Michael Yang committed
562
563
	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
	embeddingHeadsV := f.KV().EmbeddingHeadCountV()
Michael Yang's avatar
Michael Yang committed
564

Michael Yang's avatar
Michael Yang committed
565
	layers := f.Tensors().GroupLayers()
Michael Yang's avatar
Michael Yang committed
566

567
	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
568
569
570
571
572
573
574
575
576
577
578

	// Default for models unless special-cased below. These defaults mirror the
	// cache usage in llama.cpp under the assumption that models without special
	// cases below will use the llamarunner and caching will be handled by the
	// llama.cpp layer.
	//
	// This also assumes that a layer without heads or headsKV set is recurrent
	// which is usually the case. Some models (eg nemotronh) use "blocks" in
	// place of layers where some are MLP blocks that don't have any cache.
	// Models like this will need a special case below to be accurately
	// estimated.
Michael Yang's avatar
Michael Yang committed
579
	var kvTotal uint64
580
	kv = make([]uint64, f.KV().BlockCount())
581
582
	kvSizeAttn := uint64(0)
	kvSizeRecurrent := uint64(0)
583
	for i := range kv {
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
		headsL := headsArr[i]
		headsKVL := headsKVArr[i]
		if headsL > 0 && headsKVL > 0 {
			// full attention layer
			// NOTE: Assumes uniform values for all attn layers
			kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
			kvSizeAttn += kv[i]
		} else {
			// recurrent layer
			ssmDConv := f.KV().SSMConvKernel()
			ssmDState := f.KV().SSMStateSize()
			ssmDInner := f.KV().SSMInnerSize()
			ssmNGroups := f.KV().SSMGroupCount()
			nEmbdR := uint64(0)
			if ssmDConv > 0 {
				nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
			}
			nEmbdS := ssmDState * ssmDInner

			// recurrent always uses F32 in llama.cpp backend
			// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
			bytesPerElementRecurrent := kvCacheBytesPerElement("f32")

			kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
			kvSizeRecurrent += kv[i]
		}
Michael Yang's avatar
Michael Yang committed
610
		kvTotal += kv[i]
611
	}
612
	slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
Michael Yang's avatar
Michael Yang committed
613

Michael Yang's avatar
Michael Yang committed
614
	switch f.KV().Architecture() {
Michael Yang's avatar
memory  
Michael Yang committed
615
	case "llama", "llama4":
Michael Yang's avatar
Michael Yang committed
616
617
618
619
		fullOffload = max(
			4*batch*(1+4*embedding+context*(1+heads)),
			4*batch*(embedding+vocab),
		)
Michael Yang's avatar
Michael Yang committed
620
621
622

		partialOffload = 4 * batch * embedding
		partialOffload += max(
Michael Yang's avatar
Michael Yang committed
623
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
Michael Yang's avatar
Michael Yang committed
624
625
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
Michael Yang's avatar
Michael Yang committed
626

Michael Yang's avatar
Michael Yang committed
627
628
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
Michael Yang's avatar
memory  
Michael Yang committed
629
			ff := uint64(f.KV().Uint("feed_forward_length"))
Michael Yang's avatar
Michael Yang committed
630
			partialOffload = max(
Michael Yang's avatar
Michael Yang committed
631
632
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
Michael Yang's avatar
Michael Yang committed
633
634
635
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
Michael Yang's avatar
Michael Yang committed
636
637
638
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
Michael Yang's avatar
Michael Yang committed
639
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
Michael Yang's avatar
Michael Yang committed
640
641
642
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
Michael Yang's avatar
Michael Yang committed
643
644
645
	case "mllama":
		var visionTokens, tiles uint64 = 1601, 4

Michael Yang's avatar
Michael Yang committed
646
		crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
647
		for i := range kv {
Michael Yang's avatar
Michael Yang committed
648
			if slices.Contains(crossAttentionLayers, int32(i)) {
649
650
651
652
653
				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
					4 * // sizeof(float32)
					visionTokens *
					tiles
			}
Michael Yang's avatar
Michael Yang committed
654
655
		}

Michael Yang's avatar
Michael Yang committed
656
657
658
659
660
661
662
		fullOffload = max(
			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
			// vocab graph
			4*batch*(embedding+vocab),
		)

		var ropeFreqsCount uint64
Michael Yang's avatar
Michael Yang committed
663
		if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
Michael Yang's avatar
Michael Yang committed
664
			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
665
				ropeFreqsCount = ropeFreqsWeights.Elements()
Michael Yang's avatar
Michael Yang committed
666
667
668
669
670
671
672
673
674
675
676
			}
		}

		partialOffload = max(
			4*(batch*
				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
				ropeFreqsCount+
				embeddingHeadsK*context*headsKV),
			// vocab graph
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
677
	case "gemma", "gemma2", "gemma3", "gemma3n":
Michael Yang's avatar
Michael Yang committed
678
679
680
681
682
683
684
685
686
687
688
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)

		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)
689

690
691
692
693
694
		if f.KV().Architecture() == "gemma3n" {
			fullOffload *= 4
			partialOffload *= 4
		}

695
696
697
698
699
700
701
702
703
704
705
706
707
		// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
		// engine. Gemma3 always uses the Ollama engine.
		if f.KV().Architecture() == "gemma3" {
			const gemma3GlobalCacheCount = 6
			slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
			for i := range kv {
				// Every 6th layer is a global layer, which is the full context size that has already been set. The other
				// layers are the smaller local (sliding) layers.
				if (i+1)%gemma3GlobalCacheCount != 0 {
					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
				}
			}
		}
Michael Yang's avatar
Michael Yang committed
708
709
710
711
712
713
714
715
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
Michael Yang's avatar
Michael Yang committed
716
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
Michael Yang's avatar
Michael Yang committed
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)
Michael Yang's avatar
Michael Yang committed
733

Michael Yang's avatar
Michael Yang committed
734
735
736
737
		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
Michael Yang's avatar
Michael Yang committed
738
739
740
741
742
743
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
Michael Yang's avatar
Michael Yang committed
744
745
746
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
Michael Yang's avatar
Michael Yang committed
747
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
Michael Yang's avatar
Michael Yang committed
748
749
750
751
		)

		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
Michael Yang's avatar
Michael Yang committed
752
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
Michael Yang's avatar
Michael Yang committed
753
		)
Michael Yang's avatar
Michael Yang committed
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)

			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
780
	case "gptoss", "gpt-oss":
Michael Yang's avatar
Michael Yang committed
781
782
783
784
785
786
787
788
789
		kv = make([]uint64, f.KV().BlockCount())
		for i := range kv {
			kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
			if i%2 == 0 {
				kv[i] *= (uint64(numParallel)*4096 + batch)
			} else {
				kv[i] *= context
			}
		}
790

791
		partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
792
793
794
795
		if useFlashAttention {
			// rough estimate of graph size with flash attention on
			partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
		}
Michael Yang's avatar
Michael Yang committed
796
797
	}

Michael Yang's avatar
Michael Yang committed
798
	return
Michael Yang's avatar
Michael Yang committed
799
}
800
801

// SupportsKVCacheType checks if the requested cache type is supported
Michael Yang's avatar
Michael Yang committed
802
func (f GGML) SupportsKVCacheType(cacheType string) bool {
803
804
805
806
807
	if cacheType == "" || cacheType == "f16" {
		return true
	}

	return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
808
809
810
}

// SupportsFlashAttention checks if the model supports flash attention
Michael Yang's avatar
Michael Yang committed
811
812
func (f GGML) SupportsFlashAttention() bool {
	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
813
814
815
816
	if isEmbedding {
		return false
	}

817
818
819
820
	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
		return false
	}

821
	// Check head counts match and are non-zero
Michael Yang's avatar
Michael Yang committed
822
823
	headCountK := f.KV().EmbeddingHeadCountK()
	headCountV := f.KV().EmbeddingHeadCountV()
824
825
826
	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

827
828
829
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
	return slices.Contains([]string{
830
		"gemma3",
831
		"gptoss", "gpt-oss",
832
833
		"qwen3", "qwen3moe",
		"qwen3vl", "qwen3vlmoe",
834
835
836
	}, f.KV().String("general.architecture"))
}

837
838
839
840
841
842
843
// kvCacheBytesPerElement returns the number of bytes per element for a given
// KV cache type:
//   - "f32": 4
//   - "q8_0": 1
//   - "q4_0": 0.5
//   - anything else, including "f16" and "": 2
func kvCacheBytesPerElement(cacheType string) float64 {
	switch cacheType {
	case "f32":
		return 4 // full precision, used for recurrent state
	case "q8_0":
		return 1 // half of f16
	case "q4_0":
		return 0.5 // a quarter of f16
	}
	return 2 // f16, the default
}