Unverified Commit 33ee7168 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by GitHub
Browse files

Add experimental MLX backend and engine with imagegen support (#13648)



* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------
Co-authored-by: default avatarjmorganca <jmorganca@gmail.com>
parent 34d0c55e
package nn
import "github.com/ollama/ollama/x/ml"
type Embedding struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
return m.Weight.TakeAxes(ctx, hiddenState, 0)
}
package nn
import "github.com/ollama/ollama/x/ml"
type Linear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
panic("not yet ported")
// t = m.Weight.MulmatID(ctx, t, indices)
// if m.Bias != nil {
// t = t.AddID(ctx, m.Bias, indices)
// }
// return t
}
package nn
import (
"github.com/ollama/ollama/x/ml"
)
type LayerNorm struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
// slog.Info("RMSNorm", "eps", eps)
// fmt.Fprintln(os.Stderr, t.ToString())
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
// TODO this is probably model specific, not generalized...
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
return t.RMSNorm(ctx, w, eps)
}
package pooling
import (
"github.com/ollama/ollama/x/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
// case TypeMean:
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
// case TypeCLS:
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
// case TypeLast:
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}
}
package rope
import "github.com/ollama/ollama/x/ml"
// Options contains optional parameters for RoPE function
type Options struct {
Type int
Factors ml.Tensor
// YaRN options
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// MRoPE options
MRoPE struct {
Sections []int
}
}
// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5
opts.MRoPE.Sections = sections
}
}
package ml
import (
"os"
"path/filepath"
"runtime"
)
// LibPath is a path to lookup dynamic libraries
// in development it's usually 'build/lib/ollama'
// in distribution builds it's 'lib/ollama' on Windows
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()
package model
import (
"cmp"
"fmt"
"iter"
"log/slog"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type lazyIdsString struct {
ids []int32
}
func (l lazyIdsString) LogValue() slog.Value {
return slog.AnyValue(fmt.Sprint(l.ids))
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}
package model
import (
"bufio"
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
t.Fatal(err)
}
types := make([]int32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
types[id] = 1
}
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
tokens = append(tokens, token) //nolint:makezero
types = append(types, 3) //nolint:makezero
vocab[token] = int32(len(vocab))
}
}
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
merges := make([]string, 0, 50000)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "#") {
merges = append(merges, scanner.Text())
}
}
return NewBytePairEncoding(
&Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
func TestLlama(t *testing.T) {
tokenizer := llama(t)
t.Run("simple", func(t *testing.T) {
t.Parallel()
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
s, err := tokenizer.Decode([]int32{15339, 1917})
if err != nil {
t.Fatal(err)
}
if s != "hello world" {
t.Errorf("got %q, want hello world", s)
}
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
t.Run("simple repeated", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
strings.Repeat("0", 1): {15},
strings.Repeat("0", 2): {410},
strings.Repeat("0", 3): {931},
strings.Repeat("0", 4): {931, 15},
strings.Repeat("0", 5): {931, 410},
strings.Repeat("0", 6): {931, 931},
strings.Repeat("0", 7): {931, 931, 15},
strings.Repeat("0", 8): {931, 931, 410},
strings.Repeat("0", 9): {931, 931, 931},
strings.Repeat("0", 10): {931, 931, 931, 15},
strings.Repeat("0", 11): {931, 931, 931, 410},
strings.Repeat("0", 12): {931, 931, 931, 931},
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
}
}
})
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件!12345",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
})
t.Run("special", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("split", func(t *testing.T) {
t.Parallel()
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
"Hello World": {"Hello", " ", " World"},
"Hello\nWorld": {"Hello", "\n", "World"},
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
t.Parallel()
for b := 0x00; b <= 0xFF; b++ {
input := string(rune(b))
ids, err := tokenizer.Encode(input, false)
if err != nil {
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
continue
}
if b == 0x00 {
if len(decoded) != 0 {
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
}
continue
}
if decoded != input {
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
}
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
if err != nil {
b.Fatal(err)
}
for i := range 8 {
n := min(int(math.Pow10(i)), len(bts))
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})
}
}
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}
package input
import "github.com/ollama/ollama/x/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may chose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal []Multimodal
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// TODO maybe not the optimal way to handle this
// Offset of final tensor in the final batch
Offset int
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}
package model
import (
"errors"
"fmt"
_ "image/jpeg"
_ "image/png"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
_ "github.com/ollama/ollama/x/ml/backend"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
type Base struct {
b ml.Backend
config
}
type config struct {
Cache kvcache.Cache
}
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
}
func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
models[name] = f
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Struct {
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
if !vv.CanSet() {
continue
}
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
var names []string
if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
if len(names) == 0 {
// current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
}
}
}
return fullNames
}
names := fn(tagsCopy, "", "")
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
}
}
}
if !canNil(tt) || !vv.IsNil() {
allNil = false
}
}
if allNil {
return reflect.Zero(t)
}
}
return v
}
func setPointer(base Base, v reflect.Value, tags []Tag) {
vv := v
if v.Kind() == reflect.Interface {
if v.IsNil() {
return
}
vv = vv.Elem()
}
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
if f := populateFields(base, vv, tags...); f.CanAddr() {
v.Set(f.Addr())
}
}
type Tag struct {
name,
// prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
}
func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
// elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
}
}
}
return
}
func canNil(t reflect.Type) bool {
return t.Kind() == reflect.Chan ||
t.Kind() == reflect.Func ||
t.Kind() == reflect.Interface ||
t.Kind() == reflect.Map ||
t.Kind() == reflect.Pointer ||
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}
ctx.Forward(t)
return t, nil
}
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type embedModel struct {
model.Base
model.SentencePiece
*TextModel
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
return m, nil
}
//go:build mlx
package gemma3
import (
"bytes"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type Model struct {
model.Base
model.SentencePiece
*VisionModel `gguf:"vision_tower.vision_model"`
*TextModel `gguf:"language_model.model"`
*MultiModalProjector `gguf:"multi_modal_projector"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
// slog.Info("XXX Config", "c", c)
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewMLXCausalCache()
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>""
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model/input"
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
Layers []TextLayer `gguf:"layers"`
OutputNorm *nn.RMSNorm `gguf:"norm"`
Output *nn.Linear `gguf:"embed_tokens"`
*TextConfig
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
// const (
// cacheTypeSWA = iota
// cacheTypeCausal
// )
func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
ropeScale: 1, // 1 - default is 1, implied in python code
// vocabSize: vocabSize, // 262144
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"q_proj"`
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
Key *nn.Linear `gguf:"k_proj"`
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
Value *nn.Linear `gguf:"v_proj"`
Output *nn.Linear `gguf:"o_proj"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
B := hiddenState.Dim(0)
L := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
traditional := false
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
// TODO - this is wrong somehow so commenting out
// if opts.largeModelScaling {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
// } else {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
// }
scaleFactor := math.Pow(256, -0.5)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// ropeBase := m.TextConfig.ropeLocalBase
// if (layer+1)%gemmaGlobalCacheCount == 0 {
// ropeBase = m.TextConfig.ropeGlobalBase
// }
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
panic("not yet implemented")
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"up_proj"`
Down *nn.Linear `gguf:"down_proj"`
Gate *nn.Linear `gguf:"gate_proj"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *TextSelfAttention `gguf:"self_attn"`
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
MLP *TextMLP `gguf:"mlp"`
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
residual = residual.TakeAxes(ctx, outputs, 1)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
// var except []int
// for _, image := range batch.Multimodal {
// visionOutputs := image.Multimodal[0].Tensor
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
// []int{image.Index * hiddenState.Stride(1)}, 0)))
// for i := range visionOutputs.Dim(1) {
// except = append(except, image.Index+i)
// }
// }
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
if cache != nil {
// cacheType := cacheTypeSWA
// if (i+1)%gemmaGlobalCacheCount == 0 {
// cacheType = cacheTypeCausal
// }
cache.SetLayer(i)
// TODO this needs to come back
// wc := cache.(*kvcache.WrapperCache)
// wc.SetLayerType(cacheType)
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
// }
}
var offset int
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
offset = batch.Offset
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.out_proj"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}
//go:build mlx
package gemma3
import (
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}
package models
// _ "github.com/ollama/ollama/x/model/models/gemma3"
package model
import (
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/ollama/ollama/logutil"
)
const spmWhitespaceSep = "▁"
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件!12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}
func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
"<0xEA>",
"<0x41>",
"<0xC3>",
"<0xA3>",
},
Types: []int32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
},
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePiece(vocab)
tests := []struct {
name string
ids []int32
expected string
}{
{
name: "single byte token",
ids: []int32{1},
expected: "\xea",
},
{
name: "ASCII byte token",
ids: []int32{2},
expected: "A",
},
{
name: "multiple byte tokens forming UTF-8 character",
ids: []int32{3, 4},
expected: "ã",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := spm.Decode(tt.ids)
if err != nil {
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
}
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
package model
import (
"log/slog"
"slices"
"sync"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
type Vocabulary struct {
Values []string
Types []int32
Scores []float32
Merges []string
BOS, EOS []int32
AddBOS, AddEOS bool
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return slices.Contains(v.BOS, id)
case SpecialEOS:
return slices.Contains(v.EOS, id)
default:
return false
}
}
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
if v.AddBOS && len(v.BOS) > 0 {
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
ids = append([]int32{v.BOS[0]}, ids...)
}
if v.AddEOS && len(v.EOS) > 0 {
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
ids = append(ids, v.EOS[0])
}
return ids
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment