package convert

import (
	"cmp"
	"fmt"
	"log/slog"
	"regexp"
	"strconv"

	"github.com/ollama/ollama/fs/ggml"
)

// deepseek2Model holds the subset of a DeepSeek-V2/V3-family HuggingFace
// config.json needed to convert the model to GGUF. Field tags map the
// upstream JSON keys onto Go fields; KV() translates them to GGUF metadata.
type deepseek2Model struct {
	ModelParameters               // architectures, vocab_size
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	HiddenLayers          uint32  `json:"num_hidden_layers"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`

	// MLA (multi-head latent attention) geometry: per-head dims for the
	// non-rotary and rotary parts of Q/K, the low-rank KV/Q projections,
	// and the value head dim.
	RopeTheta     float32 `json:"rope_theta"`
	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
	KVLoraRank    uint32  `json:"kv_lora_rank"`
	QLoraRank     uint32  `json:"q_lora_rank"`
	VHeadDim      uint32  `json:"v_head_dim"`

	// Mixture-of-experts routing configuration.
	ExpertCount            uint32  `json:"n_routed_experts"`
	ExpertSharedCount      uint32  `json:"n_shared_experts"`
	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`

	ScoringFunc            string `json:"scoring_func"`       // expert gating: "softmax" or "sigmoid"
	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"` // layers before the first MoE block

	// YaRN-style RoPE scaling parameters.
	RopeScaling struct {
		Factor                        float32 `json:"factor"`
		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
		Type                          string  `json:"type"`
		MScaleAllDim                  float32 `json:"mscale_all_dim"`
	} `json:"rope_scaling"`

	// No JSON tag — presumably populated by the converter itself rather
	// than from config.json. TODO(review): confirm against callers.
	Architecture string
}

// KV builds the GGUF key/value metadata for a deepseek2 model: it starts
// from the shared ModelParameters entries and layers the architecture,
// attention, expert-routing, and RoPE-scaling hyperparameters on top.
func (p *deepseek2Model) KV(t *Tokenizer) KV {
	out := p.ModelParameters.KV(t)

	out["general.architecture"] = "deepseek2"
	out["general.type"] = "model"

	// Core model geometry.
	out["deepseek2.block_count"] = p.HiddenLayers
	out["deepseek2.context_length"] = p.MaxPositionEmbeddings
	out["deepseek2.embedding_length"] = p.HiddenSize
	out["deepseek2.feed_forward_length"] = p.IntermediateSize
	out["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount

	// Attention (MLA) parameters. Key length is the concatenation of the
	// non-rotary and rotary per-head dimensions.
	out["deepseek2.attention.head_count"] = p.NumAttentionHeads
	out["deepseek2.attention.head_count_kv"] = p.NumKeyValueHeads
	out["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
	out["deepseek2.attention.value_length"] = p.VHeadDim
	out["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
	out["deepseek2.attention.q_lora_rank"] = p.QLoraRank
	out["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS

	// Mixture-of-experts routing.
	out["deepseek2.expert_count"] = p.ExpertCount
	out["deepseek2.expert_shared_count"] = p.ExpertSharedCount
	out["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
	out["deepseek2.expert_used_count"] = p.ExpertUsedCount
	out["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
	out["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale

	// Gating function enum; an unrecognized name maps to 0.
	var gating uint32
	switch p.ScoringFunc {
	case "softmax":
		// not currently supported in the model, but needed for Deepseek-OCR
		gating = 1
	case "sigmoid":
		gating = 2
	}
	out["deepseek2.expert_gating_func"] = gating

	// RoPE base frequency defaults to 10000 when the config omits it.
	out["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
	out["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
	out["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
	out["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
	out["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
	out["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim

	out["tokenizer.ggml.pre"] = "deepseek-v3"

	return out
}

// Replacements returns old/new name-fragment pairs used to rewrite
// HuggingFace tensor names into their GGUF equivalents.
//
// NOTE(review): the pair order appears significant — more specific
// fragments ("mlp.shared_experts.down_proj", "mlp.gate.e_score_correction_bias")
// precede their shorter prefixes ("mlp.down_proj", "mlp.gate"). Preserve
// this ordering when editing; verify against how the converter applies them.
func (p *deepseek2Model) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.norm", "output_norm",
		"language_model.", "",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		// MLA attention projections.
		"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
		"self_attn.kv_a_layernorm", "attn_kv_a_norm",
		"self_attn.kv_b_proj", "attn_kv_b",
		"self_attn.q_a_proj", "attn_q_a",
		"self_attn.q_a_layernorm", "attn_q_a_norm",
		"self_attn.q_b_proj", "attn_q_b",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		// Shared-expert projections must be matched before the generic
		// mlp.* fragments below.
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",
		// The correction-bias fragment must precede the bare "mlp.gate".
		"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mlp.gate", "ffn_gate_inp",
	}
}

// Tensors converts the source tensors to GGUF tensors: per-expert
// gate/up/down projection weights are merged into one stacked tensor per
// layer, tensors belonging to blocks beyond the configured layer count
// (e.g. the Multi-Token Prediction layer) are dropped, and everything
// else passes through unchanged.
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
	// One merge rule per (layer, projection): fuse the per-expert tensors
	// matching the glob into a single *_exps tensor.
	merges := make([]merge, p.HiddenLayers*3)
	for i := range p.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	// Compile once, not on every call: skipLayer runs for every tensor.
	blkNumRE := regexp.MustCompile(`^blk\.(\d+)`)

	// skipLayer reports whether tensor name n belongs to a block whose
	// index is >= minValue. Names without a "blk.N" prefix are kept.
	skipLayer := func(n string, minValue uint32) bool {
		matches := blkNumRE.FindStringSubmatch(n)
		if matches == nil {
			return false
		}

		blkNum, err := strconv.Atoi(matches[1])
		if err != nil {
			return false
		}

		return uint32(blkNum) >= minValue
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		// skip any additional layers (such as the Multi-Token Prediction layer)
		if skipLayer(t.Name(), p.HiddenLayers) {
			slog.Debug("skipping layer", "name", t.Name())
			continue
		}
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}