package convert

import (
	"cmp"
	"fmt"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

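// mistral3CausalModel describes the configuration of a text-only (causal)
// Mistral 3 checkpoint as parsed from the model's JSON configuration,
// including the rope parameters needed to emit GGUF metadata.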
type mistral3CausalModel struct {
	ModelParameters

	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RopeTheta             float32 `json:"rope_theta"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`
	HeadDim               uint32  `json:"head_dim"`
	SlidingWindow         *uint32 `json:"sliding_window"`
	HiddenAct             string  `json:"hidden_act"`
	VocabSize             uint32  `json:"vocab_size"`
	RopeParameters        struct {
		BetaFast                  float32  `json:"beta_fast"`
		BetaSlow                  float32  `json:"beta_slow"`
		Factor                    float32  `json:"factor"`
		Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
		OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
		RopeType                  string   `json:"rope_type"`
		RopeTheta                 float32  `json:"rope_theta"`
		Mscale                    *float32 `json:"mscale"`
		MscaleAllDim              *float32 `json:"mscale_all_dim"`
	} `json:"rope_parameters"`
}

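// KV builds the "mistral3.*" GGUF metadata from the parsed configuration,
// preferring top-level values and falling back to rope_parameters (and to
// hidden_size/num_attention_heads for the rope dimension) where a field is
// unset.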
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral3"
	kv["mistral3.vocab_size"] = p.VocabSize

	// Text configuration
	kv["mistral3.block_count"] = p.NumHiddenLayers
	kv["mistral3.context_length"] = p.MaxPositionEmbeddings
	kv["mistral3.embedding_length"] = p.HiddenSize
	kv["mistral3.feed_forward_length"] = p.IntermediateSize
	kv["mistral3.attention.head_count"] = p.NumAttentionHeads
	kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["mistral3.attention.key_length"] = p.HeadDim
	kv["mistral3.attention.value_length"] = p.HeadDim
	kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
	kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
	kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
	kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
	kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
	kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow

	if p.RopeParameters.Mscale != nil {
		kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
	}

	if p.RopeParameters.MscaleAllDim != nil {
		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
	}

	if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
		kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
	}

	if p.RopeParameters.Llama4ScalingBeta != nil {
		kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
	}

	return kv
}

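// Tensors wraps the source tensors for writing, attaching a repacker to the
// text model's attention query/key weights so their rows are permuted before
// they are written out.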
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	for _, t := range ts {
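		// Vision tower tensors are prefixed "v."; only the text model's
		// attention Q/K weights need repacking.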
		if !strings.HasPrefix(t.Name(), "v.") {
			if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
				strings.HasSuffix(t.Name(), ".attn_k.weight") {
				t.SetRepacker(p.repack)
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

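// Replacements returns old/new substring pairs used to rename source tensor
// names to their GGUF equivalents.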
func (p *mistral3CausalModel) Replacements() []string {
	return []string{
		"model.norm", "output_norm",
		"model.", "",
		"layers", "blk",
		"transformer.layers", "blk",
		"vision_tower", "v",
		"ln_pre", "encoder_norm",
		"input_layernorm", "attn_norm",
		"post_attention_layernorm", "ffn_norm",
		"embed_tokens", "token_embd",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"attention.q_proj", "attn_q",
		"attention.k_proj", "attn_k",
		"attention.v_proj", "attn_v",
		"attention.o_proj", "attn_output",
		"attention_norm", "attn_norm",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"multi_modal_projector", "mm",
		"ffn_norm", "ffn_norm",
		"lm_head", "output",
	}
}

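// repack permutes an attention query or key weight on a per-head basis,
// interleaving the two rotary halves of each head's rows (rows i and
// i+head_dim/2 become adjacent) before the tensor is written out.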
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var dims []int
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	var heads uint32
	if strings.HasSuffix(name, ".attn_q.weight") {
		heads = p.NumAttentionHeads
	} else if strings.HasSuffix(name, ".attn_k.weight") {
		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
	} else {
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

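	// View the weight as (heads, 2, head_dim/2, cols), swap the half axis
	// with the per-half row axis so the halves interleave, then restore the
	// original 2D shape.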
	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

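	// Flatten the permuted tensor back into a contiguous []float32.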
	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	var f32s []float32
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}