package convert

import (
	"io"
	"regexp"

	"github.com/ollama/ollama/llm"
)

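// MixtralModel carries the shared ModelData used to convert a Mixtral
// checkpoint to GGUF.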
type MixtralModel struct {
	ModelData
}

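// GetTensors gathers the source tensors and routes the per-block attention
// q/k weights through mistralLayerHandler before they are written out.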
func (m *MixtralModel) GetTensors() error {
	t, err := m.Format.GetTensors(m.Path, m.Params)
	if err != nil {
		return err
	}

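	// Only the per-block attn_q and attn_k weight tensors need the
	// Mistral-specific write handler; all other tensors pass through as-is.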
	pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
	re, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}

	for _, l := range t {
		matches := re.FindAllStringSubmatch(l.Name, -1)
		if len(matches) > 0 {
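			// Copy the writer, swap in the Mistral-specific handler, and
			// store the copy back on the tensor.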
			wt := l.WriterTo.(safetensorWriterTo)
			wt.handler = mistralLayerHandler
			l.WriterTo = wt
		}
		m.Tensors = append(m.Tensors, l)
	}

	return nil
}

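// LoadVocab reads the SentencePiece tokenizer from the model directory.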
func (m *MixtralModel) LoadVocab() error {
	v, err := LoadSentencePieceTokens(m.Path, m.Params)
	if err != nil {
		return err
	}
	m.Vocab = v
	return nil
}

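// WriteGGUF encodes the model metadata, vocabulary, and tensors into a
// GGUF v3 file written to ws.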
func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
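	// Mixtral reuses the "llama" architecture metadata; the expert keys
	// below describe its mixture-of-experts layout.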
	kv := llm.KV{
		"general.architecture":          "llama",
		"general.name":                  m.Name,
		"llama.block_count":             uint32(m.Params.HiddenLayers),
		"llama.context_length":          uint32(m.Params.ContextSize),
		"llama.embedding_length":        uint32(m.Params.HiddenSize),
		"llama.feed_forward_length":     uint32(m.Params.IntermediateSize),
		"llama.attention.head_count":    uint32(m.Params.AttentionHeads),
		"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),

		"llama.rope.freq_base":                   float32(m.Params.RopeFrequencyBase),
		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),

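		// Mixture-of-experts configuration: total experts and experts
		// activated per token.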
		"llama.expert_count":      uint32(m.Params.Experts),
		"llama.expert_used_count": uint32(m.Params.ExpertsUsed),

		"llama.vocab_size":           uint32(len(m.Vocab.Tokens)),
		"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),

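		// GGUF file type 1 corresponds to F16 (float16) tensor data.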
		"general.file_type":    uint32(1),
		"tokenizer.ggml.model": "llama",

		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
		"tokenizer.ggml.scores":     m.Vocab.Scores,
		"tokenizer.ggml.token_type": m.Vocab.Types,

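		// BOS/EOS token IDs come from the source model's params; the
		// unknown token is fixed at ID 0.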
		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
		"tokenizer.ggml.unknown_token_id": uint32(0),
		"tokenizer.ggml.add_bos_token":    true,
		"tokenizer.ggml.add_eos_token":    false,
	}

	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}