ggml.go 14.2 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
package ggml
2
3
4
5

import (
	"encoding/binary"
	"errors"
Michael Yang's avatar
Michael Yang committed
6
	"fmt"
7
	"io"
Michael Yang's avatar
Michael Yang committed
8
	"log/slog"
9
	"slices"
Michael Yang's avatar
Michael Yang committed
10
	"strings"
11

Michael Yang's avatar
Michael Yang committed
12
	"github.com/ollama/ollama/fs/util/bufioutil"
13
14
)

Michael Yang's avatar
Michael Yang committed
15
16
17
18
// GGML is a decoded model file. The embedded container describes the file
// format it was read from; the embedded model exposes the decoded metadata
// (KV) and tensor descriptors (Tensors).
type GGML struct {
	container
	model
}
19

Michael Yang's avatar
Michael Yang committed
20
type model interface {
Michael Yang's avatar
Michael Yang committed
21
	KV() KV
Michael Yang's avatar
Michael Yang committed
22
	Tensors() Tensors
23
24
}

25
26
// KV holds decoded model metadata, keyed by full metadata key name
// (e.g. "general.architecture", "tokenizer.chat_template").
type KV map[string]any

Michael Yang's avatar
Michael Yang committed
27
func (kv KV) Architecture() string {
Michael Yang's avatar
Michael Yang committed
28
	return kv.String("general.architecture", "unknown")
Michael Yang's avatar
Michael Yang committed
29
30
}

31
func (kv KV) Kind() string {
Michael Yang's avatar
Michael Yang committed
32
	return kv.String("general.type", "unknown")
33
34
}

Michael Yang's avatar
Michael Yang committed
35
func (kv KV) ParameterCount() uint64 {
Michael Yang's avatar
Michael Yang committed
36
	return keyValue[uint64](kv, "general.parameter_count")
Michael Yang's avatar
Michael Yang committed
37
38
}

Michael Yang's avatar
Michael Yang committed
39
func (kv KV) FileType() fileType {
Michael Yang's avatar
Michael Yang committed
40
41
	if t := kv.Uint("general.file_type"); t > 0 {
		return fileType(t)
Michael Yang's avatar
Michael Yang committed
42
43
	}

Michael Yang's avatar
Michael Yang committed
44
	return fileTypeUnknown
Michael Yang's avatar
Michael Yang committed
45
46
47
}

func (kv KV) BlockCount() uint64 {
Michael Yang's avatar
Michael Yang committed
48
49
50
51
52
	return uint64(kv.Uint("block_count"))
}

func (kv KV) EmbeddingLength() uint64 {
	return uint64(kv.Uint("embedding_length"))
Michael Yang's avatar
Michael Yang committed
53
54
55
}

func (kv KV) HeadCount() uint64 {
Michael Yang's avatar
Michael Yang committed
56
	return uint64(kv.Uint("attention.head_count"))
Michael Yang's avatar
Michael Yang committed
57
58
59
}

func (kv KV) HeadCountKV() uint64 {
Michael Yang's avatar
Michael Yang committed
60
	return uint64(kv.Uint("attention.head_count_kv", 1))
Michael Yang's avatar
Michael Yang committed
61
62
}

Michael Yang's avatar
Michael Yang committed
63
64
func (kv KV) EmbeddingHeadCount() uint64 {
	if heads := kv.HeadCount(); heads > 0 {
Michael Yang's avatar
Michael Yang committed
65
		return kv.EmbeddingLength() / heads
Michael Yang's avatar
Michael Yang committed
66
67
68
69
70
71
	}

	return 0
}

func (kv KV) EmbeddingHeadCountK() uint64 {
Michael Yang's avatar
Michael Yang committed
72
	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
Michael Yang's avatar
Michael Yang committed
73
74
75
}

func (kv KV) EmbeddingHeadCountV() uint64 {
Michael Yang's avatar
Michael Yang committed
76
	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
Michael Yang's avatar
Michael Yang committed
77
78
}

Michael Yang's avatar
Michael Yang committed
79
func (kv KV) GQA() uint64 {
Michael Yang's avatar
Michael Yang committed
80
	return kv.HeadCount() / kv.HeadCountKV()
Michael Yang's avatar
Michael Yang committed
81
82
83
}

func (kv KV) ContextLength() uint64 {
Michael Yang's avatar
Michael Yang committed
84
	return uint64(kv.Uint("context_length"))
Michael Yang's avatar
Michael Yang committed
85
86
}

Michael Yang's avatar
Michael Yang committed
87
func (kv KV) ChatTemplate() string {
Michael Yang's avatar
Michael Yang committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
	return kv.String("tokenizer.chat_template")
}

// String returns the string value stored under key. Keys that do not start
// with "tokenizer." or "general." are prefixed with the architecture name.
// When the key is absent it returns defaultValue[0], or "" if no default was
// supplied. Only the first default is used.
func (kv KV) String(key string, defaultValue ...string) string {
	// Choose the fallback without appending to defaultValue. Appending could
	// write into the caller's backing array when they pass `s...` with spare
	// capacity.
	fallback := ""
	if len(defaultValue) > 0 {
		fallback = defaultValue[0]
	}
	return keyValue(kv, key, fallback)
}

// Uint returns the uint32 value stored under key. Keys that do not start
// with "tokenizer." or "general." are prefixed with the architecture name.
// When the key is absent it returns defaultValue[0], or 0 if no default was
// supplied. Only the first default is used.
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
	// Choose the fallback without appending to defaultValue, which could alias
	// the caller's backing array.
	var fallback uint32
	if len(defaultValue) > 0 {
		fallback = defaultValue[0]
	}
	return keyValue(kv, key, fallback)
}

// Float returns the float32 value stored under key. Keys that do not start
// with "tokenizer." or "general." are prefixed with the architecture name.
// When the key is absent it returns defaultValue[0], or 0 if no default was
// supplied. Only the first default is used.
func (kv KV) Float(key string, defaultValue ...float32) float32 {
	// Choose the fallback without appending to defaultValue, which could alias
	// the caller's backing array.
	var fallback float32
	if len(defaultValue) > 0 {
		fallback = defaultValue[0]
	}
	return keyValue(kv, key, fallback)
}

// Strings returns the string array stored under key. Keys that do not start
// with "tokenizer." or "general." are prefixed with the architecture name.
// When the key is absent it returns defaultValue[0] if one was given,
// otherwise an empty slice.
//
// Only collected values are returned. Decode skips collecting values of
// arrays larger than its maxArraySize, so such keys yield an empty slice
// instead of panicking.
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
	// A unique sentinel pointer lets us tell "key missing" apart from
	// "key present but empty".
	missing := &array{}
	r := keyValue(kv, key, missing)
	if r == missing && len(defaultValue) > 0 {
		return defaultValue[0]
	}

	// Iterate the collected values rather than r.size. The two differ when
	// the decoder did not keep the array's contents.
	s := make([]string, len(r.values))
	for i, v := range r.values {
		s[i] = v.(string)
	}

	return s
}

func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
	r := keyValue(kv, key, &array{})
	s := make([]uint32, r.size)
	for i := range r.size {
		s[i] = uint32(r.values[i].(int32))
	}

Michael Yang's avatar
Michael Yang committed
120
121
122
	return s
}

Michael Yang's avatar
Michael Yang committed
123
124
125
126
127
128
129
130
131
132
133
134
135
// keyValue looks up key in kv and returns its value as T.
//
// Keys that do not start with "tokenizer." or "general." are prefixed with
// "<architecture>." first. When the key is missing, or holds a value of a
// type other than T, a warning is logged and the fallback is returned. The
// fallback is defaultValue[0], or T's zero value when no default was given.
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
		key = kv.Architecture() + "." + key
	}

	// Guard the fallback: indexing an empty defaultValue would panic.
	var fallback T
	if len(defaultValue) > 0 {
		fallback = defaultValue[0]
	}

	if val, ok := kv[key]; ok {
		if v, ok := val.(T); ok {
			return v
		}

		// A mismatched type is malformed metadata. Report it instead of
		// panicking on the type assertion.
		slog.Warn("key has unexpected type", "key", key, "type", fmt.Sprintf("%T", val), "default", fallback)
		return fallback
	}

	slog.Warn("key not found", "key", key, "default", fallback)
	return fallback
}

136
type Tensors struct {
Michael Yang's avatar
Michael Yang committed
137
	items  []*Tensor
138
	Offset uint64
Michael Yang's avatar
Michael Yang committed
139
}
Michael Yang's avatar
Michael Yang committed
140

Michael Yang's avatar
Michael Yang committed
141
142
143
144
func (s Tensors) Items(prefix ...string) []*Tensor {
	if len(prefix) == 0 {
		return s.items
	}
145

Michael Yang's avatar
Michael Yang committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
	var items []*Tensor
	for _, t := range s.items {
		if strings.HasPrefix(t.Name, prefix[0]) {
			items = append(items, t)
		}
	}

	return items
}

func (ts Tensors) GroupLayers() map[string]Layer {
	layers := make(map[string]Layer)
	for _, t := range ts.items {
		parts := strings.Split(t.Name, ".")
		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
			if len(parts) > index+2 {
				// blk and mm should have a number after them, join it
				parts = append(
					[]string{strings.Join(parts[:index+2], ".")},
					parts[index+2:]...)
166
			}
Michael Yang's avatar
Michael Yang committed
167
		}
168

Michael Yang's avatar
Michael Yang committed
169
170
		if _, ok := layers[parts[0]]; !ok {
			layers[parts[0]] = make(Layer)
Michael Yang's avatar
Michael Yang committed
171
172
		}

Michael Yang's avatar
Michael Yang committed
173
174
175
176
		layers[parts[0]][strings.Join(parts[1:], ".")] = t
	}

	return layers
Michael Yang's avatar
Michael Yang committed
177
178
179
180
}

// Layer maps a tensor's name, relative to its layer key, to the tensor
// (see Tensors.GroupLayers).
type Layer map[string]*Tensor

Michael Yang's avatar
Michael Yang committed
181
func (l Layer) Size() (size uint64) {
Michael Yang's avatar
Michael Yang committed
182
	for _, t := range l {
Michael Yang's avatar
Michael Yang committed
183
		size += t.Size()
Michael Yang's avatar
Michael Yang committed
184
185
186
187
188
	}

	return size
}

189
// Tensor describes one tensor of a model: its name, its ggml type code
// (Kind), an Offset, and its Shape. Offset and the embedded io.WriterTo are
// excluded from JSON encoding.
type Tensor struct {
	Name   string `json:"name"`
	Kind   uint32 `json:"kind"`
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}

// block returns the layer index n parsed from a name of the form
// "blk.<n>.<rest>", or -1 when the name does not start that way.
func (t Tensor) block() (n int) {
	if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
		return -1
	}

	return
}

// blockSize returns how many elements one storage block of t.Kind holds:
// 1 for unquantized scalar types, 32 for the listed 32-element quantizations,
// and 256 for everything else.
func (t Tensor) blockSize() uint64 {
	switch t.Kind {
	case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
		return 1
	case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, (4, 5: retired Q4_2/Q4_3), Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
		return 32
	default: // All others
		return 256
	}
}

// typeSize returns the number of bytes one block of t.Kind occupies. For
// scalar types this is the size of one element. Kinds not listed here return
// 0, which makes Size report 0.
func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		// ggml's block_q8_K has a float32 scale (4 bytes), not a half, plus
		// the int8 quants and one int16 sum per 16 quants: 292 bytes.
		return 4 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + blockSize/4 + blockSize/8
	case 19: // IQ1_S
		return 2 + blockSize/8 + blockSize/16
	case 20: // IQ4_NL
		return 2 + blockSize/2
	case 21: // IQ3_S
		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
	case 22: // IQ2_S
		return 2 + blockSize/4 + blockSize/16
	case 23: // IQ4_XS
		return 2 + 2 + blockSize/2 + blockSize/64
	case 24: // I8
		return 1
	case 25: // I16
		return 2
	case 26: // I32
		return 4
	case 27: // I64
		return 8
	case 28: // F64
		return 8
	case 29: // IQ1_M
		return blockSize/8 + blockSize/16 + blockSize/32
	case 30: // BF16
		// blockSize already treats BF16 as a 1-element scalar type. Without
		// this case BF16 tensors fell through to 0 and reported Size 0.
		return 2
	default:
		return 0
	}
}

// parameters returns the total number of elements: the product of Shape, or
// 1 for an empty Shape.
func (t Tensor) parameters() uint64 {
	var count uint64 = 1
	for _, n := range t.Shape {
		count *= n
	}
	return count
}

// Size returns the number of bytes the tensor's data occupies: elements ×
// bytes per block ÷ elements per block.
func (t Tensor) Size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}

296
297
type container interface {
	Name() string
Michael Yang's avatar
Michael Yang committed
298
	Decode(io.ReadSeeker) (model, error)
299
300
301
}

// Magic numbers for GGML-family model files. Each is compared against the
// file's first four bytes read as a little-endian uint32.
const (
	// FILE_MAGIC_GGML marks `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// FILE_MAGIC_GGMF marks `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// FILE_MAGIC_GGJT marks `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// FILE_MAGIC_GGLA marks `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676c61

	// FILE_MAGIC_GGUF_LE and FILE_MAGIC_GGUF_BE mark `gguf` files (versioned,
	// gguf); Decode reads the former as little-endian and the latter as
	// big-endian.
	FILE_MAGIC_GGUF_LE = 0x46554747
	FILE_MAGIC_GGUF_BE = 0x47475546
)

Bruce MacDonald's avatar
Bruce MacDonald committed
315
316
// ErrUnsupportedFormat reports a model file whose format is not supported.
// NOTE(review): Decode in this file returns its own "invalid file magic" error
// rather than this sentinel — confirm which callers rely on it.
var ErrUnsupportedFormat = errors.New("unsupported model format")

Michael Yang's avatar
Michael Yang committed
317
func DetectContentType(b []byte) string {
Michael Yang's avatar
Michael Yang committed
318
319
320
321
322
323
324
325
326
	switch binary.LittleEndian.Uint32(b[:4]) {
	case FILE_MAGIC_GGML:
		return "ggml"
	case FILE_MAGIC_GGMF:
		return "ggmf"
	case FILE_MAGIC_GGJT:
		return "ggjt"
	case FILE_MAGIC_GGLA:
		return "ggla"
327
	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
Michael Yang's avatar
Michael Yang committed
328
329
330
331
332
333
		return "gguf"
	default:
		return ""
	}
}

Michael Yang's avatar
Michael Yang committed
334
// Decode decodes a GGML model from the given reader.
335
336
337
338
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
Michael Yang's avatar
Michael Yang committed
339
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
340
341
342
343
344
345
	if maxArraySize == 0 {
		maxArraySize = 1024
	}

	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)

346
	var magic uint32
Michael Yang's avatar
Michael Yang committed
347
	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
Michael Yang's avatar
Michael Yang committed
348
		return nil, 0, err
349
350
351
	}

	var c container
352
353
	switch magic {
	case FILE_MAGIC_GGUF_LE:
354
		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
355
	case FILE_MAGIC_GGUF_BE:
356
		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
357
	default:
Michael Yang's avatar
Michael Yang committed
358
		return nil, 0, errors.New("invalid file magic")
359
360
	}

Michael Yang's avatar
Michael Yang committed
361
	model, err := c.Decode(rs)
362
	if err != nil {
Michael Yang's avatar
Michael Yang committed
363
		return nil, 0, err
364
365
	}

Michael Yang's avatar
Michael Yang committed
366
367
	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
Michael Yang's avatar
Michael Yang committed
368
		return nil, 0, err
Michael Yang's avatar
Michael Yang committed
369
370
	}

371
	// final model type
372
373
374
	return &GGML{
		container: c,
		model:     model,
Michael Yang's avatar
Michael Yang committed
375
	}, offset, nil
376
}
Michael Yang's avatar
Michael Yang committed
377

Michael Yang's avatar
Michael Yang committed
378
379
380
381
382
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
	embedding := f.KV().EmbeddingLength()
	heads := f.KV().HeadCount()
	headsKV := f.KV().HeadCountKV()
	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
Michael Yang's avatar
Michael Yang committed
383

Michael Yang's avatar
Michael Yang committed
384
385
386
	embeddingHeads := f.KV().EmbeddingHeadCount()
	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
	embeddingHeadsV := f.KV().EmbeddingHeadCountV()
Michael Yang's avatar
Michael Yang committed
387

Michael Yang's avatar
Michael Yang committed
388
	layers := f.Tensors().GroupLayers()
Michael Yang's avatar
Michael Yang committed
389

390
	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
Michael Yang's avatar
Michael Yang committed
391
	kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
Michael Yang's avatar
Michael Yang committed
392

Michael Yang's avatar
Michael Yang committed
393
	switch f.KV().Architecture() {
Michael Yang's avatar
Michael Yang committed
394
	case "llama":
Michael Yang's avatar
Michael Yang committed
395
396
397
398
		fullOffload = max(
			4*batch*(1+4*embedding+context*(1+heads)),
			4*batch*(embedding+vocab),
		)
Michael Yang's avatar
Michael Yang committed
399
400
401

		partialOffload = 4 * batch * embedding
		partialOffload += max(
Michael Yang's avatar
Michael Yang committed
402
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
Michael Yang's avatar
Michael Yang committed
403
404
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
Michael Yang's avatar
Michael Yang committed
405

Michael Yang's avatar
Michael Yang committed
406
407
		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
			// mixtral 8x22b
Michael Yang's avatar
Michael Yang committed
408
			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
Michael Yang's avatar
Michael Yang committed
409
			partialOffload = max(
Michael Yang's avatar
Michael Yang committed
410
411
				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
Michael Yang's avatar
Michael Yang committed
412
413
414
			)
		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
			// mixtral 8x7b
Michael Yang's avatar
Michael Yang committed
415
416
417
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
Michael Yang's avatar
Michael Yang committed
418
				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
Michael Yang's avatar
Michael Yang committed
419
420
421
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
Michael Yang's avatar
Michael Yang committed
422
423
424
	case "mllama":
		var visionTokens, tiles uint64 = 1601, 4

Michael Yang's avatar
Michael Yang committed
425
		if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
Michael Yang's avatar
Michael Yang committed
426
427
428
			kv = headsKV *
				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
				(2* // sizeof(float16)
Michael Yang's avatar
Michael Yang committed
429
					(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
Michael Yang's avatar
Michael Yang committed
430
431
432
433
434
435
436
					context +
					4* // sizeof(float32)
						uint64(crossAttentionLayers.size)* // num cross attention layers
						visionTokens*
						tiles)
		}

Michael Yang's avatar
Michael Yang committed
437
438
439
440
441
442
443
		fullOffload = max(
			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
			// vocab graph
			4*batch*(embedding+vocab),
		)

		var ropeFreqsCount uint64
Michael Yang's avatar
Michael Yang committed
444
		if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
Michael Yang's avatar
Michael Yang committed
445
446
447
448
449
450
451
452
453
454
455
456
457
			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
				ropeFreqsCount = ropeFreqsWeights.parameters()
			}
		}

		partialOffload = max(
			4*(batch*
				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
				ropeFreqsCount+
				embeddingHeadsK*context*headsKV),
			// vocab graph
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
Michael Yang's avatar
Michael Yang committed
458
459
460
461
462
463
464
465
466
467
468
469
	case "gemma", "gemma2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
		)

		partialOffload = max(
			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context*8+
				embedding*embeddingHeadsK*heads*9/16,
		)
Michael Yang's avatar
Michael Yang committed
470
471
472
473
474
475
476
477
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
Michael Yang's avatar
Michael Yang committed
478
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
Michael Yang's avatar
Michael Yang committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)
Michael Yang's avatar
Michael Yang committed
495

Michael Yang's avatar
Michael Yang committed
496
497
498
499
		partialOffload = max(
			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
			4*batch*(2+3*embedding+context+context*heads),
		)
Michael Yang's avatar
Michael Yang committed
500
501
502
503
504
505
	case "stablelm":
		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
		partialOffload = max(
			4*batch*(vocab+2*embedding),
			fullOffload,
		)
Michael Yang's avatar
Michael Yang committed
506
507
508
	case "deepseek2":
		fullOffload = max(
			4*batch*(3*embedding+vocab),
Michael Yang's avatar
Michael Yang committed
509
			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
Michael Yang's avatar
Michael Yang committed
510
511
512
513
		)

		partialOffload = max(
			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
Michael Yang's avatar
Michael Yang committed
514
			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
Michael Yang's avatar
Michael Yang committed
515
		)
Michael Yang's avatar
Michael Yang committed
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
	case "chatglm":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
			fullOffload = max(
				fullOffload,
				4*batch*(2+
					2*embedding+
					context+
					context*heads+
					embeddingHeadsK*heads+
					qkvBias.Shape[0]),
			)

			partialOffload = max(
				partialOffload,
				4*batch*(1+
					2*embedding+
					embeddingHeadsK*heads+
					context+
					context*heads)+
					4*embeddingHeadsK*context+
					4*context*embeddingHeadsK+
					4*qkvBias.Shape[0],
			)
		}
Michael Yang's avatar
Michael Yang committed
542
543
	}

Michael Yang's avatar
Michael Yang committed
544
	return
Michael Yang's avatar
Michael Yang committed
545
}
546
547

// SupportsKVCacheType checks if the requested cache type is supported
Michael Yang's avatar
Michael Yang committed
548
549
func (f GGML) SupportsKVCacheType(cacheType string) bool {
	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
550
551
552
}

// SupportsFlashAttention checks if the model supports flash attention
Michael Yang's avatar
Michael Yang committed
553
554
func (f GGML) SupportsFlashAttention() bool {
	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
555
556
557
558
559
	if isEmbedding {
		return false
	}

	// Check head counts match and are non-zero
Michael Yang's avatar
Michael Yang committed
560
561
	headCountK := f.KV().EmbeddingHeadCountK()
	headCountV := f.KV().EmbeddingHeadCountV()
562
563
564
565
566
567
568
569
570
571
572
573
574
575
	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

// kvCacheBytesPerElement returns the approximate number of bytes one KV cache
// element takes for cacheType: 1 for "q8_0", 0.5 for "q4_0", and 2 (f16) for
// anything else, including "".
//
// NOTE(review): the quantized figures ignore per-block scales. By this file's
// typeSize formulas, q8_0 is 34/32 and q4_0 is 18/32 bytes per element.
func kvCacheBytesPerElement(cacheType string) float64 {
	if cacheType == "q8_0" {
		return 1 // 1/2 of fp16
	}
	if cacheType == "q4_0" {
		return 0.5 // 1/4 of fp16
	}
	return 2 // f16 (default)
}