// gemma.go

package convert

import (
	"encoding/binary"
	"fmt"
	"io"
	"log/slog"
	"os"
	"strings"

	"github.com/d4l3k/go-bfloat16"
	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/llm"
)

// GemmaModel handles conversion of a Gemma model to GGUF (see its GetTensors,
// LoadVocab and WriteGGUF methods). All of its state — path, params, format,
// vocabulary and tensor list — comes from the embedded ModelData.
type GemmaModel struct {
	ModelData
}

// gemmaLayerHandler converts a single norm-weight tensor while it is written.
//
// The handler performs these steps:
//   - It reads exactly r.end-r.start raw bytes from f.
//   - It decodes them as bfloat16 into float32.
//   - It adds 1 to every element via addOnes.
//   - It writes the float32 result to w in byte order r.bo.
//
// NOTE(review): f is assumed to already be positioned at the start of this
// tensor's data by the caller — confirm against safetensorWriterTo.WriteTo.
// The +1 presumably folds Gemma's "(1 + weight)" RMSNorm scaling into the
// stored weights — confirm.
//
// Errors are wrapped with the tensor name so a failed conversion can be traced
// back to the offending tensor.
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
	slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))

	// io.ReadFull fills raw in place; binary.Read on a []byte would read into a
	// temporary buffer and copy it over. Byte order is irrelevant for bytes.
	raw := make([]byte, r.end-r.start)
	if _, err := io.ReadFull(f, raw); err != nil {
		return fmt.Errorf("reading tensor %q: %w", r.t.Name, err)
	}

	values, err := addOnes(bfloat16.DecodeFloat32(raw), int(r.t.Shape[0]))
	if err != nil {
		return fmt.Errorf("adjusting tensor %q: %w", r.t.Name, err)
	}

	if err := binary.Write(w, r.bo, values); err != nil {
		return fmt.Errorf("writing tensor %q: %w", r.t.Name, err)
	}
	return nil
}

// addOnes returns a newly allocated slice whose i-th element is data[i]+1.
//
// vectorSize is the number of elements in the vector and must equal len(data).
// A mismatch is reported as an error up front, instead of handing tensor.New a
// backing slice that disagrees with the requested shape. Empty input yields a
// nil slice. On error the returned slice is nil.
func addOnes(data []float32, vectorSize int) ([]float32, error) {
	if len(data) != vectorSize {
		return nil, fmt.Errorf("vector size %d does not match data length %d", vectorSize, len(data))
	}
	if vectorSize == 0 {
		return nil, nil
	}

	vec := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
	ones := tensor.Ones(tensor.Float32, vectorSize)

	sum, err := vec.Add(ones)
	if err != nil {
		return nil, err
	}

	rows, err := native.SelectF32(sum, 0)
	if err != nil {
		return nil, err
	}

	// Flatten the selected rows back into a single vector; the total size is
	// known, so allocate it once.
	out := make([]float32, 0, vectorSize)
	for _, row := range rows {
		out = append(out, row...)
	}
	return out, nil
}

// GetTensors loads the model's tensors through m.Format and appends them to
// m.Tensors.
//
// Every tensor whose name ends in "norm.weight" gets gemmaLayerHandler
// installed on its safetensorWriterTo, so its values are transformed when the
// tensor is written out.
//
// It returns an error in two cases:
//   - Loading the tensors fails.
//   - A norm tensor is backed by some other writer type, which cannot carry
//     the handler. Previously this panicked on an unchecked type assertion.
//
// m.Tensors is only modified after every tensor has been processed, so a
// failure leaves it unchanged.
func (m *GemmaModel) GetTensors() error {
	t, err := m.Format.GetTensors(m.Path, m.Params)
	if err != nil {
		return err
	}

	slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))

	for i, l := range t {
		if !strings.HasSuffix(l.Name, "norm.weight") {
			continue
		}

		wt, ok := l.WriterTo.(safetensorWriterTo)
		if !ok {
			return fmt.Errorf("tensor %q: cannot attach norm-weight handler to writer of type %T", l.Name, l.WriterTo)
		}
		wt.handler = gemmaLayerHandler
		l.WriterTo = wt
		// l is the range copy; store it back so the change is kept.
		t[i] = l
	}

	m.Tensors = append(m.Tensors, t...)
	return nil
}

// LoadVocab loads the tokenizer vocabulary for the model at m.Path (via
// LoadSentencePieceTokens) and stores it in m.Vocab. On error m.Vocab is left
// untouched.
func (m *GemmaModel) LoadVocab() error {
	v, err := LoadSentencePieceTokens(m.Path, m.Params)
	if err != nil {
		return err
	}
	m.Vocab = v
	return nil
}

Michael Yang's avatar
Michael Yang committed
95
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
	kv := llm.KV{
		"general.architecture":                   "gemma",
		"general.name":                           m.Name,
		"gemma.context_length":                   uint32(m.Params.ContextSize),
		"gemma.embedding_length":                 uint32(m.Params.HiddenSize),
		"gemma.block_count":                      uint32(m.Params.HiddenLayers),
		"gemma.feed_forward_length":              uint32(m.Params.IntermediateSize),
		"gemma.attention.head_count":             uint32(m.Params.AttentionHeads),
		"gemma.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
		"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
		"gemma.attention.key_length":             uint32(m.Params.HeadDimension),
		"gemma.attention.value_length":           uint32(m.Params.HeadDimension),
		"general.file_type":                      uint32(1),
		"tokenizer.ggml.model":                   "llama",

		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
		"tokenizer.ggml.scores":     m.Vocab.Scores,
		"tokenizer.ggml.token_type": m.Vocab.Types,

		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
		"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
		"tokenizer.ggml.unknown_token_id": uint32(3),
		"tokenizer.ggml.add_bos_token":    true,
		"tokenizer.ggml.add_eos_token":    false,
	}

Michael Yang's avatar
Michael Yang committed
123
	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
124
}