Unverified Commit 33ee7168 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by GitHub
Browse files

Add experimental MLX backend and engine with imagegen support (#13648)



* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------
Co-authored-by: default avatarjmorganca <jmorganca@gmail.com>
parent 34d0c55e
//go:build mlx
package safetensors
import (
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestLoadModelWeights(t *testing.T) {
// Skip if no model available
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Check we found tensors
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Fatal("no tensors found")
}
t.Logf("found %d tensors", len(tensors))
// Check HasTensor
if !mw.HasTensor(tensors[0]) {
t.Errorf("HasTensor(%q) = false", tensors[0])
}
if mw.HasTensor("nonexistent.weight") {
t.Error("HasTensor returned true for nonexistent tensor")
}
}
func TestGetTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Skip("no tensors")
}
// Load first tensor
arr, err := mw.GetTensor(tensors[0])
if err != nil {
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
}
// Verify it has a shape
shape := arr.Shape()
if len(shape) == 0 {
t.Error("tensor has no shape")
}
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
}
func TestLoadWithDtype(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Load all tensors as bfloat16
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Load: %v", err)
}
// Get a tensor from cache
tensors := mw.ListTensors()
arr, err := mw.Get(tensors[0])
if err != nil {
t.Fatalf("Get: %v", err)
}
// Verify dtype (unless it was already bf16)
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
}
func TestLookupTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// HasTensor returns false for nonexistent
if mw.HasTensor("nonexistent") {
t.Error("HasTensor should return false for nonexistent")
}
// HasTensor returns true for existing tensor
tensors := mw.ListTensors()
if !mw.HasTensor(tensors[0]) {
t.Error("HasTensor should return true for existing tensor")
}
}
func TestParseSafetensorHeader(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
// Find a safetensors file
entries, err := os.ReadDir(modelDir)
if err != nil {
t.Fatal(err)
}
var stFile string
for _, e := range entries {
if filepath.Ext(e.Name()) == ".safetensors" {
stFile = filepath.Join(modelDir, e.Name())
break
}
}
if stFile == "" {
t.Skip("no safetensors file found")
}
header, err := parseSafetensorHeader(stFile)
if err != nil {
t.Fatalf("parseSafetensorHeader: %v", err)
}
if len(header) == 0 {
t.Error("header is empty")
}
// Check a tensor has valid info
for name, info := range header {
if info.Dtype == "" {
t.Errorf("%s: empty dtype", name)
}
if len(info.Shape) == 0 {
t.Errorf("%s: empty shape", name)
}
break // just check one
}
}
# Tokenizer
Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package.
## Features
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
- **SentencePiece** - Gemma style with `▁` space handling
- **WordPiece** - BERT style with `##` continuation tokens
- **Parallel encoding** - Automatic parallelization for inputs >4KB
- **HuggingFace compatible** - Loads `tokenizer.json` directly
## Usage
```go
import "github.com/ollama/ollama/x/imagegen/tokenizer"
// Load from HuggingFace model directory
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
if err != nil {
log.Fatal(err)
}
// Encode text to token IDs
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
// Decode back to text
text := tok.Decode(ids)
// Check special tokens
if tok.IsEOS(ids[len(ids)-1]) {
// End of sequence
}
```
## Performance
Benchmarks on Apple M3 Max:
| Input Size | Encode | Decode | Tokens |
|------------|--------|--------|--------|
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
Comparison with other implementations (10 MB input):
| Implementation | Encode Speed | Notes |
|----------------|--------------|-------|
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
## Performance Opportunities
Potential optimizations not yet implemented:
| Optimization | Expected Gain | Complexity |
|--------------|---------------|------------|
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
| Custom regex engine (like tiktoken) | 1.5-2x | High |
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
| Assembly BPE merge loop | 1.2-1.5x | High |
| Memoization for repeated substrings | Variable | Low |
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
## Not Yet Implemented
| Feature | Used By | Notes |
|---------|---------|-------|
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
| Custom pretokenizers | Model-specific | Beyond standard patterns |
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
## Files
| File | Description |
|------|-------------|
| `tokenizer.go` | Main implementation (~1000 lines) |
| `tokenizer_test.go` | Tests and benchmarks |
| `testdata/` | Mini tokenizer for unit tests |
{"model": {"type": "BPE", "vocab": {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "fé": 59958, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "r": 81, "q": 80, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "¡": 94, "¢": 95, "£": 96, "¤": 97, "¥": 98, "¦": 99, "§": 100, "¨": 101, "World": 10343, "©": 102, "ª": 103, "«": 104, "¬": 105, "®": 106, "world": 14957, "¯": 107, "°": 108, "±": 109, "²": 110, "³": 111, "´": 112, "µ": 113, "¶": 114, "·": 115, "¸": 116, "¹": 117, "º": 118, "»": 119, "¼": 120, "½": 121, "¾": 122, "¿": 123, "À": 124, "Á": 125, "Â": 126, "Ã": 127, "Ä": 128, "Å": 129, "Æ": 130, "Ç": 131, "È": 132, "É": 133, "Ê": 134, "Ë": 135, "Ì": 136, "Í": 137, "Î": 138, "Ï": 139, "Ð": 140, "Ñ": 141, "Ò": 142, "Ó": 143, "Ô": 144, "Õ": 145, "Ö": 146, "×": 147, "Ø": 148, "Ù": 149, "Ú": 150, "Û": 151, "Ü": 152, "Ý": 153, "Þ": 154, "ß": 155, "à": 156, "á": 157, "â": 158, "ã": 159, "ä": 160, "å": 161, "æ": 162, "ç": 163, "è": 164, "é": 165, "ê": 166, "ë": 167, "ì": 168, "Ġhello": 24748, "í": 169, "î": 170, "ï": 171, "ð": 172, "ñ": 173, "Hello": 9906, "ò": 174, "ó": 175, "ô": 176, "õ": 177, "ö": 178, "Ġ{}": 4792, "÷": 179, "ø": 180, "ù": 181, "ú": 182, "û": 183, "ü": 184, "ý": 185, "þ": 186, "ÿ": 187, "Ā": 188, "ā": 189, "Ă": 190, "ă": 191, "Ċ": 198, "Ą": 192, "ą": 193, "Ć": 194, "ć": 195, "Ĉ": 196, "ĉ": 197, "ċ": 199, "Č": 200, "č": 201, "Ď": 202, "ď": 203, "Đ": 204, "đ": 205, "Ē": 206, "ē": 207, "Ĕ": 208, "ĕ": 209, "Ė": 210, "ė": 211, "Ę": 212, "ę": 213, "Ġ": 220, "Ě": 214, "ě": 215, "Ĝ": 216, "ĝ": 217, "Ğ": 218, "ğ": 219, "ġ": 221, "Ģ": 222, "ģ": 223, "Ĥ": 224, "ĥ": 225, "Ħ": 226, "ħ": 227, "Ĩ": 228, "ĩ": 229, "Ī": 230, "ī": 231, "Ĭ": 232, "ĭ": 233, "Į": 234, "į": 235, "İ": 236, "ı": 237, "IJ": 238, "ij": 239, "Ĵ": 240, "ĵ": 241, "Ķ": 242, "ķ": 243, "ĸ": 244, "Ĺ": 245, "ĺ": 246, "Ļ": 247, "ļ": 248, "Ľ": 249, "ĠĠ": 256, "ľ": 250, "Ŀ": 251, "ŀ": 252, "Ł": 253, "rer": 38149, "ĠĠĠ": 262, "ł": 254, "Ń": 255, "'m": 2846, "'re": 2351, "can": 4919, "func": 2900, "()": 368, "Ġworld": 1917, "Ġmain": 1925, "00": 410, "123": 4513, "000": 931, "ca": 936, "'t": 956, "é": 978, "hello": 15339, "Ġw": 289, "orld": 1410, "Ġwor": 4191, "ld": 509, "main": 3902, "Ġm": 296, "ain": 467, "Ġma": 7643, "in": 258, "Ġmai": 17154, "re": 265, "'r": 97670, "unc": 1371, "fun": 12158, "fu": 33721, "nc": 1031, "ma": 1764, "mai": 77585, "wor": 50810, "or": 269, "Ġwo": 24670, "23": 1419, "12": 717, "{}": 6390, "Ġ{": 314, "an": 276, "ello": 4896, "Hel": 33813, "lo": 385, "Hell": 81394, "un": 359, "hel": 50222, "hell": 57195, "ai": 2192, "wo": 1146, "Ġh": 305, "Ġhel": 11591, "Ġhell": 15123, "el": 301, "He": 1548, "er": 261, "he": 383, "ell": 616, "ll": 657}, "merges": ["Ġ Ġ", "Ġ ĠĠ", "ĠĠ Ġ", "( )", "0 0", "0 00", "00 0", "c a", "' t", "à ©", "Ġ world", "Ġw orld", "Ġwor ld", "Ġ main", "Ġm ain", "Ġma in", "Ġmai n", "' re", "'r e", "' m", "f unc", "fun c", "fu nc", "m ain", "ma in", "mai n", "Ġ wor", "Ġw or", "Ġwo r", "1 23", "12 3", "Ġ {}", "Ġ{ }", "c an", "ca n", "{ }", "Ġ ma", "Ġm a", "H ello", "Hel lo", "Hell o", "W orld", "f un", "fu n", "w orld", "wor ld", "h ello", "hel lo", "hell o", "Ġ mai", "Ġm ai", "Ġma i", "Ġ wo", "Ġw o", "Ġ hello", "Ġh ello", "Ġhel lo", "Ġhell o", "f u", "H el", "He l", "r er", "re r", "h el", "he l", "w or", "wo r", "h ell", "he ll", "hel l", "f é", "m ai", "ma i", "H ell", "He ll", "Hel l", "' r"]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, "behavior": "Isolated", "invert": false}, {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": false}]}, "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true}, "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>", "special": true}, {"id": 128001, "content": "<|end_of_text|>", "special": true}]}
\ No newline at end of file
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import (
"encoding/json"
"fmt"
"os"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
TokenizerWordPiece // BERT style with ## continuations
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE, SentencePiece, and WordPiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
typ TokenizerType // Algorithm type
unkToken int32 // [UNK] token ID for WordPiece fallback
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// loadSpecialTokenConfig loads special token configuration from HuggingFace companion files.
//
// Loading priority for EOS tokens (can be single int or []int):
// 1. generation_config.json - eos_token_id (preferred, matches HuggingFace generation)
// 2. config.json - eos_token_id (model config fallback)
// 3. tokenizer_config.json - eos_token string + add_bos/add_eos flags
// 4. special_tokens_map.json - final fallback
func loadSpecialTokenConfig(dir string, t *Tokenizer) {
// Helper to parse eos_token_id which can be int or []int
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json (eos_token_id can be int or []int)
if data, err := os.ReadFile(dir + "generation_config.json"); err == nil {
var config struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(data, &config); err == nil {
if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json (model config, same format)
if len(t.vocab.EOS) == 0 || t.vocab.BOS < 0 {
if data, err := os.ReadFile(dir + "config.json"); err == nil {
var config struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(data, &config); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
}
// Priority 3: tokenizer_config.json (token strings + add_bos/add_eos flags)
if data, err := os.ReadFile(dir + "tokenizer_config.json"); err == nil {
var config struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(data, &config); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(config.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(config.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(config.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if config.AddBOSToken != nil {
t.vocab.AddBOS = *config.AddBOSToken
}
if config.AddEOSToken != nil {
t.vocab.AddEOS = *config.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json (final fallback)
if data, err := os.ReadFile(dir + "special_tokens_map.json"); err == nil {
var tokensMap map[string]interface{}
if err := json.Unmarshal(data, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt
func Load(path string) (*Tokenizer, error) {
// Check if path is a directory
if info, err := os.Stat(path); err == nil && info.IsDir() {
dir := strings.TrimSuffix(path, "/") + "/"
// Try tokenizer.json first
if data, err := os.ReadFile(dir + "tokenizer.json"); err == nil {
return loadFromTokenizerJSON(data, dir)
}
// Fall back to vocab.json + merges.txt
return LoadVocabMerges(path)
}
// It's a file - read it directly
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer: %w", err)
}
// Get directory for loading companion files
dir := ""
if idx := strings.LastIndex(path, "/"); idx >= 0 {
dir = path[:idx+1]
}
return loadFromTokenizerJSON(data, dir)
}
// loadFromTokenizerJSON parses a tokenizer.json file
func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE" or "WordPiece"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS)
// WordPiece models don't have merges
var mergesStrings []string
if raw.Model.Type != "WordPiece" && raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add special tokens to vocabulary
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
if tok.Special {
t.specialTokens[tok.Content] = tok.ID
}
}
// Load special token configuration from companion files
loadSpecialTokenConfig(dir, t)
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case raw.Model.Type == "WordPiece":
t.typ = TokenizerWordPiece
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
return t, nil
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
// Sort special tokens by length (longest first) to match greedily
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
// Encode tokenizes text to token IDs. Parallelizes for large inputs (>10KB).
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Second: collect all pretokenizer chunks
type chunk struct {
text string
isSpecial bool
}
var allChunks []chunk
if t.pretokenizer != nil {
re := t.pretokenizer
for _, part := range parts {
if _, ok := t.specialTokens[part]; ok {
allChunks = append(allChunks, chunk{part, true})
continue
}
// Split by pretokenizer regex
type match struct{ start, end int }
var matches []match
offset := 0
for offset < len(part) {
loc := re.FindStringIndex(part[offset:])
if loc == nil {
break
}
matches = append(matches, match{offset + loc[0], offset + loc[1]})
offset += loc[1]
}
// Apply whitespace boundary fix for Python regex compatibility
for i := 0; i < len(matches)-1; i++ {
m := part[matches[i].start:matches[i].end]
next := part[matches[i+1].start:matches[i+1].end]
if isNonNewlineWhitespace(m) && len(next) > 0 {
firstRune, _ := utf8.DecodeRuneInString(next)
if unicode.IsLetter(firstRune) {
lastSpaceStart := matches[i].end
for j := matches[i].end; j > matches[i].start; {
r, size := utf8.DecodeLastRuneInString(part[matches[i].start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > matches[i].start {
matches[i].end = lastSpaceStart
matches[i+1].start = lastSpaceStart
} else {
matches[i+1].start = matches[i].start
matches[i].end = matches[i].start
}
}
}
}
for _, m := range matches {
if m.end > m.start {
allChunks = append(allChunks, chunk{part[m.start:m.end], false})
}
}
}
} else {
// No pretokenizer - treat each part as a single chunk
for _, part := range parts {
if _, ok := t.specialTokens[part]; ok {
allChunks = append(allChunks, chunk{part, true})
} else {
allChunks = append(allChunks, chunk{part, false})
}
}
}
// Encode chunks - parallel for large inputs (>4KB), sequential otherwise
var ids []int32
if len(s) < 4096 {
for _, c := range allChunks {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
ids = append(ids, id)
}
} else {
ids = t.encodeChunkInto(c.text, ids)
}
}
} else {
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []chunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
r = append(r, id)
}
} else {
r = t.encodeChunkInto(c.text, r)
}
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice
// Uses BPE merge algorithm when merges are available, otherwise longest-match
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if t.typ == TokenizerWordPiece {
return t.encodeWordPieceInto(s, ids)
}
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Repeatedly merges the pair with lowest rank until no more merges possible.
// Works correctly with empty merges (falls back to individual rune/byte encoding).
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
// Start with individual runes as parts
runes := []rune(encoded)
parts := make([]string, len(runes))
for i, r := range runes {
parts[i] = string(r)
}
// Repeatedly merge lowest-rank pair
for len(parts) > 1 {
minRank := int(0x7FFFFFFF)
minIdx := -1
for i := 0; i < len(parts)-1; i++ {
// Merge key format: "token1 token2" (space-separated)
mergeKey := parts[i] + " " + parts[i+1]
if rank, ok := t.vocab.Merges[mergeKey]; ok {
if rank < minRank {
minRank = rank
minIdx = i
}
}
}
if minIdx < 0 {
break // No more merges possible
}
// Merge the pair
parts[minIdx] = parts[minIdx] + parts[minIdx+1]
parts = append(parts[:minIdx+1], parts[minIdx+2:]...)
}
// Convert parts to token IDs
for _, part := range parts {
if id, ok := t.vocab.Reverse[part]; ok {
ids = append(ids, id)
} else {
// Byte fallback for unknown parts
for _, b := range []byte(part) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
}
return ids
}
// encodeWordPieceInto appends WordPiece tokens to ids and returns extended slice
// Uses greedy longest-match with ## prefix for continuation tokens
func (t *Tokenizer) encodeWordPieceInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Check if entire string is in vocabulary (common case)
if id, ok := t.vocab.Reverse[s]; ok {
return append(ids, id)
}
runes := []rune(s)
start := 0
for start < len(runes) {
end := len(runes)
found := false
// Greedy longest-match
for end > start {
substr := string(runes[start:end])
if start > 0 {
// Continuation token: prefix with ##
substr = "##" + substr
}
if id, ok := t.vocab.Reverse[substr]; ok {
ids = append(ids, id)
found = true
start = end
break
}
end--
}
if !found {
// No match found - use [UNK] token or skip
if t.unkToken >= 0 {
ids = append(ids, t.unkToken)
}
start++
}
}
return ids
}
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerWordPiece:
// WordPiece style: strip ## prefix from continuation tokens
if strings.HasPrefix(token, "##") {
sb.WriteString(token[2:])
} else {
sb.WriteString(token)
}
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// NULL byte (0x00 encoded as 0x0100)
sb.WriteByte(0)
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}
// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
func LoadVocabMerges(dir string) (*Tokenizer, error) {
vocabPath := dir + "/vocab.json"
mergesPath := dir + "/merges.txt"
addedTokensPath := dir + "/added_tokens.json"
// Load vocab
vocabData, err := os.ReadFile(vocabPath)
if err != nil {
return nil, fmt.Errorf("failed to read vocab.json: %w", err)
}
vocabMap := make(map[string]int32)
if err := json.Unmarshal(vocabData, &vocabMap); err != nil {
return nil, fmt.Errorf("failed to parse vocab.json: %w", err)
}
// Load merges
mergesData, err := os.ReadFile(mergesPath)
if err != nil {
return nil, fmt.Errorf("failed to read merges.txt: %w", err)
}
mergesLines := strings.Split(string(mergesData), "\n")
var mergesStrings []string
for _, line := range mergesLines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
mergesStrings = append(mergesStrings, line)
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(vocabMap)),
Reverse: vocabMap,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Load added tokens if exists
if addedData, err := os.ReadFile(addedTokensPath); err == nil {
addedMap := make(map[string]int32)
if err := json.Unmarshal(addedData, &addedMap); err == nil {
for token, id := range addedMap {
vocabMap[token] = id
t.specialTokens[token] = id
}
}
}
// Build values array
for token, id := range vocabMap {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Load special token configuration from companion files
loadSpecialTokenConfig(dir+"/", t)
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// GPT-2/tiktoken pretokenizer pattern
pattern := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex: %w", err)
}
t.pretokenizer = re
return t, nil
}
//go:build mlx
package tokenizer
import (
"bytes"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"sync"
"testing"
)
// TestPatternCompilation validates that HuggingFace pretokenizer patterns
// can be rewritten for Go's RE2 regexp engine and compiled successfully.
func TestPatternCompilation(t *testing.T) {
patterns := []struct {
name string
pattern string
}{
{"llama3", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
{"qwen2", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
{"gpt4o", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
{"gpt2", `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`},
{"deepseek_cjk", `[一-龥\x{3040}-ゟ゠-ヿ]+`},
}
for _, p := range patterns {
t.Run(p.name, func(t *testing.T) {
rewritten := rewritePatternForRE2(p.pattern)
if _, err := regexp.Compile(rewritten); err != nil {
t.Errorf("failed to compile pattern: %v\noriginal: %s\nrewritten: %s", err, p.pattern, rewritten)
}
})
}
}
// TestRoundtrip verifies the fundamental property: encode(text) -> decode -> text
// This is the key invariant from tiktoken's test suite.
func TestRoundtrip(t *testing.T) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
// Test cases covering key edge cases from tiktoken
inputs := []string{
// Empty and simple
"",
"a",
"hello",
"hello world",
// Whitespace edge cases
" ",
" ",
" ",
" hello",
"hello ",
" hello ",
"hello world",
"hello world",
"\t",
"\n",
"\r\n",
"hello\nworld",
"hello\n\nworld",
// Contractions
"don't",
"I'm",
"we'll",
"they're",
"it's",
"DON'T", // uppercase
// Numbers
"123",
"1234567890",
"3.14159",
"$100",
"50%",
// Unicode
"こんにちは", // Japanese
"你好", // Chinese
"مرحبا", // Arabic (RTL)
"🎉", // Emoji
"Hello 世界", // Mixed
"café", // Accented
"naïve", // Diaeresis
"Ω≈ç√∫", // Math symbols
// Code
"func main() {}",
"if (x == 0) { return; }",
"import \"fmt\"",
"x := 42",
"// comment",
"/* block */",
// Repetition (tiktoken specifically tests this)
"aaaa",
"aaaaaaaaaaaa",
strings.Repeat("a", 100),
strings.Repeat("hello ", 50),
// Punctuation
"...",
"!!!",
"???",
"hello, world!",
"(parentheses)",
"[brackets]",
"{braces}",
// Mixed complexity
"The quick brown fox jumps over the lazy dog.",
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
"func TestRoundtrip(t *testing.T) { t.Run(\"test\", func(t *testing.T) {}) }",
}
for _, input := range inputs {
name := input
if len(name) > 30 {
name = name[:30] + "..."
}
if name == "" {
name = "<empty>"
}
name = strings.ReplaceAll(name, "\n", "\\n")
name = strings.ReplaceAll(name, "\t", "\\t")
t.Run(name, func(t *testing.T) {
tokens := tok.Encode(input, false)
decoded := tok.Decode(tokens)
if decoded != input {
t.Errorf("roundtrip failed:\n input: %q\n tokens: %v\n decoded: %q", input, tokens, decoded)
}
})
}
}
// TestSpecialTokens verifies that special tokens are handled correctly
func TestSpecialTokens(t *testing.T) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
// Special tokens should be preserved through encode/decode
t.Run("bos_preserved", func(t *testing.T) {
if tok.BOS() < 0 {
t.Skip("no BOS token")
}
tokens := tok.Encode("hello", true)
if len(tokens) == 0 || tokens[0] != tok.BOS() {
t.Errorf("BOS not prepended: got %v, want first token to be %d", tokens, tok.BOS())
}
})
t.Run("special_token_split", func(t *testing.T) {
// If we have special tokens, verify they're split correctly
for tokenStr, tokenID := range tok.specialTokens {
input := "before" + tokenStr + "after"
tokens := tok.Encode(input, false)
found := false
for _, id := range tokens {
if id == tokenID {
found = true
break
}
}
if !found {
t.Errorf("special token %q (id=%d) not found in encoding of %q: %v",
tokenStr, tokenID, input, tokens)
}
}
})
}
// TestConcurrency verifies thread-safe encoding
func TestConcurrency(t *testing.T) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "The quick brown fox jumps over the lazy dog."
expected := tok.Encode(input, false)
var wg sync.WaitGroup
errors := make(chan error, 100)
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
got := tok.Encode(input, false)
if len(got) != len(expected) {
errors <- nil // just signal error
return
}
for j := range got {
if got[j] != expected[j] {
errors <- nil
return
}
}
}()
}
wg.Wait()
close(errors)
if len(errors) > 0 {
t.Errorf("concurrent encoding produced inconsistent results")
}
}
// TestIntegration runs against real model directories, comparing with Python transformers.
// Skips if model weights are not available.
func TestIntegration(t *testing.T) {
models := []string{
"../weights/Llama-3.2-1B",
"../weights/gemma-3-1b-it",
"../weights/gpt-oss-20b",
}
// Test inputs covering various edge cases
inputs := []string{
"Hello, world!",
"The quick brown fox jumps over the lazy dog.",
"こんにちは世界",
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
"1234567890",
" spaces ",
"don't won't can't",
}
for _, modelPath := range models {
modelName := filepath.Base(modelPath)
t.Run(modelName, func(t *testing.T) {
tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
if _, err := os.Stat(tokenizerPath); err != nil {
t.Skipf("skipping: %s not found", tokenizerPath)
}
tok, err := Load(tokenizerPath)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
for _, input := range inputs {
t.Run(truncate(input, 20), func(t *testing.T) {
// Test roundtrip
tokens := tok.Encode(input, false)
decoded := tok.Decode(tokens)
if decoded != input {
t.Errorf("roundtrip failed:\n input: %q\n decoded: %q", input, decoded)
}
// Compare with Python if available
if pythonTokens, err := pythonEncode(modelPath, input); err == nil {
if !equalInt32Slice(tokens, pythonTokens) {
t.Errorf("mismatch with Python:\n go: %v\n python: %v", tokens, pythonTokens)
}
}
})
}
})
}
}
// pythonEncode calls Python transformers to encode text, for comparison
func pythonEncode(modelPath, text string) ([]int32, error) {
script := `
import sys, json
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(sys.argv[1])
tokens = tok.encode(sys.argv[2], add_special_tokens=False)
print(json.dumps(tokens))
`
cmd := exec.Command("python3", "-c", script, modelPath, text)
var out bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = nil
if err := cmd.Run(); err != nil {
return nil, err
}
// Parse JSON array
var tokens []int32
output := strings.TrimSpace(out.String())
if output == "" || output == "[]" {
return []int32{}, nil
}
// Simple parsing for [1, 2, 3] format
output = strings.Trim(output, "[]")
if output == "" {
return []int32{}, nil
}
for _, s := range strings.Split(output, ",") {
s = strings.TrimSpace(s)
var v int32
if _, err := parseIntSimple(s, &v); err == nil {
tokens = append(tokens, v)
}
}
return tokens, nil
}
func parseIntSimple(s string, v *int32) (bool, error) {
var n int64
for _, c := range s {
if c >= '0' && c <= '9' {
n = n*10 + int64(c-'0')
}
}
*v = int32(n)
return true, nil
}
func equalInt32Slice(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
// TestBPEPretokenizer verifies BPE pretokenizer splits text correctly
// using the GPT-2 style regex pattern (no dependency on tokenizer files)
func TestBPEPretokenizer(t *testing.T) {
pattern := `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
re := regexp.MustCompile(rewritePatternForRE2(pattern))
tests := []struct {
input string
expected []string
}{
{"Hello", []string{"Hello"}},
{"Hello world", []string{"Hello", " world"}},
{"Hello, world!", []string{"Hello", ",", " world", "!"}},
{"don't", []string{"don", "'t"}},
{"I'm", []string{"I", "'m"}},
{"123", []string{"123"}},
{"12345", []string{"12345"}}, // GPT-2 pattern matches any digit sequence
{"a b", []string{"a", " ", " b"}}, // whitespace boundary: last space prepends to word
{" ", []string{" "}}, // pure whitespace stays together
{"\n\n", []string{"\n\n"}}, // newlines stay together
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
// Get regex matches
matches := re.FindAllStringIndex(tt.input, -1)
var chunks []string
for _, m := range matches {
chunks = append(chunks, tt.input[m[0]:m[1]])
}
// Apply whitespace boundary fix (same logic as Encode)
for i := 0; i < len(chunks)-1; i++ {
if isNonNewlineWhitespace(chunks[i]) && len(chunks[i+1]) > 0 {
r, _ := []rune(chunks[i+1])[0], 0
if r >= 'A' && r <= 'z' { // simplified letter check
// Move last space to next chunk
if len(chunks[i]) > 0 {
lastSpace := chunks[i][len(chunks[i])-1:]
chunks[i] = chunks[i][:len(chunks[i])-1]
chunks[i+1] = lastSpace + chunks[i+1]
}
}
}
}
// Filter empty chunks
var result []string
for _, c := range chunks {
if c != "" {
result = append(result, c)
}
}
if len(result) != len(tt.expected) {
t.Errorf("got %v, want %v", result, tt.expected)
return
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("chunk %d: got %q, want %q", i, result[i], tt.expected[i])
}
}
})
}
}
// TestSentencePiecePretokenizer verifies SentencePiece doesn't use pretokenizer
// and correctly replaces spaces with ▁ (no dependency on tokenizer files)
func TestSentencePiecePretokenizer(t *testing.T) {
// SentencePiece has no pretokenizer - whole text is one chunk
// Spaces are replaced with ▁ during encoding
tests := []struct {
input string
expected string // after space replacement
}{
{"Hello", "Hello"},
{"Hello world", "Hello▁world"},
{"Hello, world!", "Hello,▁world!"},
{" spaces ", "▁▁▁spaces▁▁▁"},
{" Hello", "▁Hello"},
{"Hello ", "Hello▁"},
{"a b c", "a▁b▁c"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
// SentencePiece encoding: replace space with ▁
result := strings.ReplaceAll(tt.input, " ", "▁")
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
// TestWordPiecePretokenizer verifies WordPiece (BERT) pretokenizer splits correctly
// BertPreTokenizer splits on whitespace and punctuation
func TestWordPiecePretokenizer(t *testing.T) {
// BertPreTokenizer behavior: split on whitespace and punctuation
// Whitespace is stripped, punctuation becomes separate tokens
tests := []struct {
input string
expected []string
}{
{"Hello", []string{"Hello"}},
{"Hello world", []string{"Hello", "world"}}, // whitespace stripped
{"Hello, world!", []string{"Hello", ",", "world", "!"}}, // punct separate
{"don't", []string{"don", "'", "t"}}, // apostrophe separate (unlike BPE)
{" spaces ", []string{"spaces"}}, // whitespace stripped
{"Hello.World", []string{"Hello", ".", "World"}}, // punct splits
{"test@email.com", []string{"test", "@", "email", ".", "com"}},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := splitBertStyle(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("got %v, want %v", result, tt.expected)
return
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("token %d: got %q, want %q", i, result[i], tt.expected[i])
}
}
})
}
}
// splitBertStyle mimics BertPreTokenizer: split on whitespace and punctuation
func splitBertStyle(s string) []string {
var result []string
var current strings.Builder
for _, r := range s {
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
// Whitespace: flush current token, don't add whitespace
if current.Len() > 0 {
result = append(result, current.String())
current.Reset()
}
} else if isPunct(r) {
// Punctuation: flush current, add punct as separate token
if current.Len() > 0 {
result = append(result, current.String())
current.Reset()
}
result = append(result, string(r))
} else {
current.WriteRune(r)
}
}
if current.Len() > 0 {
result = append(result, current.String())
}
return result
}
func isPunct(r rune) bool {
// Common ASCII punctuation
return (r >= '!' && r <= '/') || (r >= ':' && r <= '@') ||
(r >= '[' && r <= '`') || (r >= '{' && r <= '~')
}
// TestRepeatedDigits verifies correct tokenization of repeated digit sequences.
// Llama-style tokenizers split digits in groups of 1-3 due to the \p{N}{1,3} pattern.
func TestRepeatedDigits(t *testing.T) {
tok, err := Load("./testdata/mini_llama.json")
if err != nil {
t.Skipf("mini_llama.json not available: %v", err)
}
// Pattern: 1 digit, 2 digits, 3 digits, then repeats
// "0" -> [single], "00" -> [double], "000" -> [triple]
// "0000" -> [triple, single], etc.
tests := []struct {
input string
count int // expected token count
}{
{"0", 1},
{"00", 1},
{"000", 1},
{"0000", 2}, // 3 + 1
{"00000", 2}, // 3 + 2
{"000000", 2}, // 3 + 3
{"0000000", 3},
{"00000000", 3},
{"000000000", 3},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
ids := tok.Encode(tt.input, false)
if len(ids) != tt.count {
t.Errorf("Encode(%q) = %d tokens, want %d", tt.input, len(ids), tt.count)
}
// Verify roundtrip
decoded := tok.Decode(ids)
if decoded != tt.input {
t.Errorf("Decode(Encode(%q)) = %q", tt.input, decoded)
}
})
}
}
// TestNullByte verifies that null bytes roundtrip correctly
func TestNullByte(t *testing.T) {
tok, err := Load("./testdata/mini_llama.json")
if err != nil {
t.Skipf("mini_llama.json not available: %v", err)
}
ids := tok.Encode("\x00", false)
decoded := tok.Decode(ids)
if decoded != "\x00" {
t.Errorf("null byte roundtrip failed: got %q, want %q", decoded, "\x00")
}
}
// TestTokenizerTypeDetection verifies correct detection of tokenizer types
func TestTokenizerTypeDetection(t *testing.T) {
tests := []struct {
name string
decoder string
expected TokenizerType
}{
{
name: "ByteLevel decoder (BPE)",
decoder: `{"type": "ByteLevel"}`,
expected: TokenizerBPE,
},
{
name: "Sequence with Replace ▁ (SentencePiece)",
decoder: `{
"type": "Sequence",
"decoders": [
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "}
]
}`,
expected: TokenizerSentencePiece,
},
{
name: "null decoder (BPE default)",
decoder: `null`,
expected: TokenizerBPE,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isSPM := detectSentencePiece([]byte(tt.decoder))
var got TokenizerType
if isSPM {
got = TokenizerSentencePiece
} else {
got = TokenizerBPE
}
if got != tt.expected {
t.Errorf("got %v, want %v", got, tt.expected)
}
})
}
}
// TestPADTokenDefault verifies PAD() returns -1 when not configured
func TestPADTokenDefault(t *testing.T) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
// mini_llama.json has no PAD token configured, should return -1
if got := tok.PAD(); got != -1 {
t.Errorf("PAD() = %d, want -1 (not configured)", got)
}
}
// TestPADTokenFromConfig verifies PAD token is loaded from tokenizer_config.json
func TestPADTokenFromConfig(t *testing.T) {
// Create temp directory with tokenizer files
dir := t.TempDir()
// Write minimal tokenizer.json
tokenizerJSON := `{
"model": {
"type": "BPE",
"vocab": {"<|endoftext|>": 0, "hello": 1, "world": 2},
"merges": []
},
"added_tokens": [
{"id": 0, "content": "<|endoftext|>", "special": true}
]
}`
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
t.Fatalf("failed to write tokenizer.json: %v", err)
}
// Write tokenizer_config.json with pad_token
configJSON := `{
"pad_token": "<|endoftext|>"
}`
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write tokenizer_config.json: %v", err)
}
tok, err := Load(dir)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
if got := tok.PAD(); got != 0 {
t.Errorf("PAD() = %d, want 0 (<|endoftext|>)", got)
}
}
// TestPADTokenFromSpecialTokensMap verifies PAD falls back to special_tokens_map.json
func TestPADTokenFromSpecialTokensMap(t *testing.T) {
dir := t.TempDir()
// Write minimal tokenizer.json
tokenizerJSON := `{
"model": {
"type": "BPE",
"vocab": {"<pad>": 0, "hello": 1, "world": 2},
"merges": []
},
"added_tokens": [
{"id": 0, "content": "<pad>", "special": true}
]
}`
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
t.Fatalf("failed to write tokenizer.json: %v", err)
}
// Write special_tokens_map.json with pad_token
mapJSON := `{
"pad_token": "<pad>"
}`
if err := os.WriteFile(filepath.Join(dir, "special_tokens_map.json"), []byte(mapJSON), 0o644); err != nil {
t.Fatalf("failed to write special_tokens_map.json: %v", err)
}
tok, err := Load(dir)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
if got := tok.PAD(); got != 0 {
t.Errorf("PAD() = %d, want 0 (<pad>)", got)
}
}
// TestPADTokenWithContentObject verifies PAD token works with {"content": "..."} format
func TestPADTokenWithContentObject(t *testing.T) {
dir := t.TempDir()
// Write minimal tokenizer.json
tokenizerJSON := `{
"model": {
"type": "BPE",
"vocab": {"[PAD]": 0, "hello": 1},
"merges": []
},
"added_tokens": [
{"id": 0, "content": "[PAD]", "special": true}
]
}`
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
t.Fatalf("failed to write tokenizer.json: %v", err)
}
// Write tokenizer_config.json with pad_token as object (HuggingFace format)
configJSON := `{
"pad_token": {"content": "[PAD]", "lstrip": false, "normalized": false}
}`
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write tokenizer_config.json: %v", err)
}
tok, err := Load(dir)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
if got := tok.PAD(); got != 0 {
t.Errorf("PAD() = %d, want 0 ([PAD])", got)
}
}
// Benchmarks
func BenchmarkEncode(b *testing.B) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
b.Fatalf("failed to load tokenizer: %v", err)
}
inputs := []struct {
name string
text string
}{
{"short", "Hello, world!"},
{"medium", "The quick brown fox jumps over the lazy dog. " + strings.Repeat("This is a test. ", 10)},
{"long", strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.SetBytes(int64(len(input.text)))
for i := 0; i < b.N; i++ {
tok.Encode(input.text, false)
}
})
}
}
func BenchmarkDecode(b *testing.B) {
tok, err := Load("testdata/mini_llama.json")
if err != nil {
b.Fatalf("failed to load tokenizer: %v", err)
}
text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)
tokens := tok.Encode(text, false)
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok.Decode(tokens)
}
}
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
package kvcache
// import (
// "errors"
// "fmt"
// "log/slog"
// "math"
// "slices"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
// DType ml.DType
// // swaWindowSize is the number of tokens that will be included in the mask
// // during attention operations. swaMemorySize is the number of tokens that
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
// // for unlimited or if sliding window attention is not being used.
// swaWindowSize int32
// swaMemorySize int32
// chunkSize int32
// opts CausalOptions
// // maxBatch is the largest batch that we might receive
// maxBatch int
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // size of the current batch
// curBatchSize int
// // locations for data storage for this batch
// curLoc ml.Tensor
// // mask of the cache as used by this batch
// curMask ml.Tensor
// // the active layer for Get and Put
// curLayer int
// // locations in the cache that are needed for this batch
// curCellRange cellRange
// // curSequences is the sequences corresponding to this pass's entries in the cache
// curSequences []int
// // curPositions is the positions corresponding to this pass's entries in the cache
// curPositions []int32
// // ** cache metadata **
// // for each possible location in the cache, stores the position and set of sequences
// // that reference the data there
// cells []cacheCell
// // maps from sequence to the range of locations where it is stored in the cache
// cellRanges map[int]cellRange
// // ** cache data storage **
// shiftFn shiftFn
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// kHeadDims, vHeadDims, numKVHeads map[int]int
// }
// type cacheCell struct {
// pos int32
// sequences []int
// }
// type cellRange struct {
// min int
// max int
// }
// func NewCausalCache(shift shiftFn) *Causal {
// return &Causal{
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// swaMemorySize: memorySize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
// return &Causal{
// chunkSize: chunkSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if c.config.CachePadding == 0 {
// c.config.CachePadding = 1
// }
// if c.config.MaskBatchPadding == 0 {
// c.config.MaskBatchPadding = 1
// }
// // TODO what types do we handle here?
// // if c.config.MaskDType == ml.DTypeOther {
// // c.config.MaskDType = ml.DTypeFloat32
// // }
// if c.swaWindowSize == 0 {
// c.swaWindowSize = math.MaxInt32
// }
// if c.swaMemorySize == 0 {
// c.swaMemorySize = c.swaWindowSize
// }
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
// // causing a cache break. As an optimization, only do this when we have parallel sequences
// // because the extra token will live in the batch buffer and won't get overwritten if we
// // only have a single sequence.
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
// }
// if int(c.swaMemorySize) >= capacity {
// c.swaMemorySize = math.MaxInt32
// }
// if c.swaMemorySize < c.swaWindowSize {
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
// }
// var cacheSize int
// if c.swaMemorySize == math.MaxInt32 {
// cacheSize = maxSequences * capacity
// } else {
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
// }
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
// c.cells = make([]cacheCell, cacheSize)
// c.DType = dtype
// c.cellRanges = make(map[int]cellRange)
// c.backend = backend
// c.maxBatch = maxBatch
// }
// func (c *Causal) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
// // panic("XXX Causal.StartForward")
// c.curBatchSize = len(batch.Positions)
// c.curSequences = batch.Sequences
// c.curPositions = batch.Positions
// c.opts.Except = nil
// var locs []int32
// if !reserve {
// c.updateSlidingWindow()
// var err error
// locs, err = c.findLocs()
// if err != nil {
// return err
// }
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
// for i, pos := range batch.Positions {
// seq := batch.Sequences[i]
// loc := int(locs[i])
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// seqRange = newRange()
// }
// seqRange.min = min(seqRange.min, loc)
// c.curCellRange.min = min(c.curCellRange.min, loc)
// seqRange.max = max(seqRange.max, loc)
// c.curCellRange.max = max(c.curCellRange.max, loc)
// c.cellRanges[seq] = seqRange
// }
// } else {
// // If we are reserving memory, don't update any of the cache metadata but set the size
// // to the worst case.
// locs = make([]int32, c.curBatchSize)
// for i := range locs {
// locs[i] = int32(i)
// }
// c.curCellRange.min = 0
// c.curCellRange.max = len(c.cells) - 1
// }
// // XXX Building up the locs for what's already processed (if any)
// dummyLocs := []int{}
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// } else {
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
// dummyLocs = append(dummyLocs, i)
// }
// }
// }
// }
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
// slog.Info("XXX Causal.StartForward", "locs", locs)
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
// c.curMask = c.buildMask(ctx)
// return nil
// }
// func newRange() cellRange {
// return cellRange{
// min: math.MaxInt,
// max: 0,
// }
// }
// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
// loc := make([]int32, 0, c.curBatchSize)
// for i := range c.cells {
// if len(c.cells[i].sequences) == 0 {
// loc = append(loc, int32(i))
// if len(loc) >= c.curBatchSize {
// return loc, nil
// }
// }
// }
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }
// func (c *Causal) updateSlidingWindow() {
// c.curCellRange = newRange()
// if c.swaMemorySize == math.MaxInt32 {
// for _, seq := range c.curSequences {
// if seqRange, ok := c.cellRanges[seq]; ok {
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
// }
// }
// return
// }
// type lowestPosition struct {
// pos int32
// curBatch bool
// }
// // create a map of unique sequences to the lowest position in that sequence
// lowestPos := make(map[int]lowestPosition)
// for i := range c.curPositions {
// seq := c.curSequences[i]
// lowest, ok := lowestPos[seq]
// if !ok {
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
// } else if c.curPositions[i] < lowest.pos {
// lowest.pos = c.curPositions[i]
// }
// lowestPos[seq] = lowest
// }
// // for any sequences are not part of this batch, clean up any tokens
// // that are no longer needed after the processing of the previous
// // batch
// for seq, seqRange := range c.cellRanges {
// if _, ok := lowestPos[seq]; !ok {
// var last int32
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// last = max(last, c.cells[i].pos)
// }
// }
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
// }
// }
// // delete any entries that are beyond the window of the oldest position in the sequence
// for seq, lowest := range lowestPos {
// oldRange, ok := c.cellRanges[seq]
// if !ok {
// continue
// }
// newRange := newRange()
// for i := oldRange.min; i <= oldRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// newRange.min = min(newRange.min, i)
// newRange.max = max(newRange.max, i)
// }
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
// c.curCellRange.min = min(c.curCellRange.min, i)
// c.curCellRange.max = max(c.curCellRange.max, i)
// }
// }
// }
// c.cellRanges[seq] = newRange
// }
// }
// func roundDown(length, pad int) int {
// return (length / pad) * pad
// }
// func roundUp(length, pad int) int {
// return ((length + pad - 1) / pad) * pad
// }
// // Builds a mask of history x batch indicating whether for each token in the batch the
// // token in the history should apply. This is based on both the sequence and causality (the
// // position of the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// // Align and pad the two dimensions as required by the backend
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// length := c.curCellRange.max - c.curCellRange.min + 1
// mask := make([]float32, batchSize*length)
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// }
// }
// }
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
// // has already been masked out because the sequence doesn't match.
// for i := c.curBatchSize * length; i < len(mask); i++ {
// mask[i] = float32(math.Inf(-1))
// }
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
// // if c.config.MaskDType != ml.DTypeFloat32 {
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
// // }
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
// return maskTensor
// }
// func (c *Causal) SetLayer(layer int) {
// c.curLayer = layer
// }
// type CausalOptions struct {
// // Enabled controls whether the causal mask is generated for a particular index in a batch
// Except []int
// }
// // SetCausal disables causal mask generation for a particular range of indicies in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
// if !slices.Equal(c.opts.Except, opts.Except) {
// c.opts = opts
// if ctx != nil {
// c.curMask = c.buildMask(ctx)
// }
// }
// }
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// key := c.keys[c.curLayer]
// value := c.values[c.curLayer]
// kHeadDim := c.kHeadDims[c.curLayer]
// vHeadDim := c.vHeadDims[c.curLayer]
// numKVHeads := c.numKVHeads[c.curLayer]
// // rowSize := numKVHeads * c.curBatchSize
// // cachedSize := c.curMask.Dim(1)
// cachedSize := c.curLoc.Dim(0)
// // kCellSize := kHeadDim * numKVHeads
// // vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get full cache", "key", key)
// slog.Info("XXX Causal.Get full cache", "value", value)
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
// // panic("XXX")
// // fmt.Fprintln(os.Stderr, key.ToString())
// // panic("full cache value")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not converted
// // vHeadDim := value.Dim(1)
// // elemSize := value.Stride(2)
// // value = value.AsStrided(ctx,
// // []int{numKVHeads, vHeadDim, cachedSize},
// // []int{value.Stride(0), value.Stride(1)},
// // elemSize*c.curCellRange.min,
// // )
// // } else {
// // vHeadDim := c.vHeadDims[c.curLayer]
// // rowSize := value.Stride(2)
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
// // panic("XXX")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
// // panic("XXX")
// // }
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
// // // the 1 becomes trailing and messes up later operations
// // // This isn't the right solution, but works around it...
// // if c.curMask.Dim(1) == 1 {
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
// // }
// // fmt.Fprintln(os.Stderr, key.ToString())
// // fmt.Fprintln(os.Stderr, value.ToString())
// // panic("XXX")
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
// return key, value, c.curMask
// }
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
// kHeadDim := key.Dim(3)
// vHeadDim := value.Dim(3)
// numKVHeads := key.Dim(1)
// batchSize := key.Dim(2)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
// // panic("XXX")
// if c.curBatchSize != batchSize {
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
// }
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
// if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
// c.kHeadDims[c.curLayer] = kHeadDim
// c.vHeadDims[c.curLayer] = vHeadDim
// c.numKVHeads[c.curLayer] = numKVHeads
// }
// if _, ok := c.values[c.curLayer]; !ok {
// // if c.config.PermutedV {
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
// // } else {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
// // }
// }
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
// // panic("XXX")
// // curLoc := 0 // TODO c.curLoc is now a tensor
// // kSize := numKVHeads * kHeadDim
// // vSize := numKVHeads * vHeadDim
// // start := []int{int(curLoc), 0}
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
// // strides := []int{1, 1}
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
// // panic("input value")
// // fmt.Fprintln(os.Stderr, t.ToString())
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not adjusted
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
// // value = value.Transpose(ctx, 2, 0, 1, 3)
// // valueCache := c.values[c.curLayer]
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
// // } else {
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
// // }
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
// // panic("XXX")
// }
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
// seqRange := newRange()
// for i := range c.cells {
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
// if slices.Contains(c.cells[i].sequences, dstSeq) {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
// }
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// c.cellRanges[dstSeq] = seqRange
// }
// func (c *Causal) CanResume(seq int, pos int32) bool {
// if c.swaMemorySize == math.MaxInt32 {
// return true
// }
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// return false
// }
// // for sliding window, check that the window of the new sequence is contained in
// // the window of what we are storing
// var first int32 = math.MaxInt32
// var last int32 = -1
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// first = min(first, c.cells[i].pos)
// last = max(last, c.cells[i].pos)
// }
// }
// if last == -1 {
// return false
// }
// posWindowStart := max(0, pos-c.swaWindowSize)
// return posWindowStart >= first && pos <= last+1
// }
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
// if c.shiftFn == nil {
// return ErrNotSupported
// }
// seqRange := c.cellRanges[seq]
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
// size := min(seqRange.max-start+1, c.maxBatch)
// offsets := make([]int32, size)
// var batchFirst, batchLast int
// batchFirst = -1
// for i := range offsets {
// cell := c.cells[start+i]
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
// offsets[i] = offset
// if batchFirst < 0 {
// batchFirst = i
// }
// batchLast = i
// }
// }
// if batchFirst < 0 {
// continue
// }
// offsets = offsets[batchFirst : batchLast+1]
// slog.Info("XXX Causal.shift creating new temporary context")
// ctx := c.backend.NewContext()
// kShift := ctx.Input().FromInts(offsets, len(offsets))
// for i, key := range c.keys {
// if key == nil {
// continue
// }
// kHeadDim := key.Dim(2)
// numKVHeads := key.Dim(1)
// rowSize := key.Stride(0)
// key = key.AsStrided(ctx,
// []int{len(offsets), numKVHeads, kHeadDim},
// []int{key.Stride(0), key.Stride(1)},
// rowSize*(start+batchFirst),
// )
// roped, err := c.shiftFn(ctx, i, key, kShift)
// if err != nil {
// ctx.Close()
// return err
// }
// ctx.Forward(roped.Copy(ctx, key))
// }
// ctx.Compute()
// ctx.Close()
// }
// return nil
// }
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
// // should return an error, which will trigger the runner to evaluate the full history and
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
// // results in use after free, so we don't do it for now.
// var offset int32
// if endIndex != math.MaxInt32 {
// offset = beginIndex - endIndex
// }
// seqRange := newRange()
// for i := range c.cells {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// if c.cells[i].pos >= endIndex {
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// return errors.New("shifting cells shared by multiple sequences not supported")
// }
// c.cells[i].pos += offset
// }
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// }
// if seqRange == newRange() {
// delete(c.cellRanges, seq)
// return nil
// }
// c.cellRanges[seq] = seqRange
// if endIndex != math.MaxInt32 {
// err := c.shift(seq, endIndex+offset, offset)
// if err != nil {
// return err
// }
// }
// return nil
// }
package kvcache
// import (
// "fmt"
// "math"
// "slices"
// "testing"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type testCase struct {
// name string
// in []float32
// inShape []int
// seqs []int
// pos []int32
// expected []float32
// expectedShape []int
// expectedMask []float32
// }
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// t.Helper()
// for _, permuted := range []bool{false, true} {
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// fn(t, &testBackend{permutedV: permuted})
// })
// }
// }
// func TestStore(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// inShape: []int{2, 3, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// expectedShape: []int{2, 3, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{115, 215, 125, 225, 135, 235},
// inShape: []int{2, 3, 1},
// seqs: []int{0},
// pos: []int32{4},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// expectedShape: []int{2, 3, 5},
// expectedMask: []float32{0, 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWA(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 6, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWASeparateBatches(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "First seq 0",
// in: []float32{1, 2},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{0, 1},
// expected: []float32{1, 2},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 0",
// in: []float32{3, 4},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{2, 3},
// expected: []float32{2, 3, 4},
// expectedShape: []int{1, 1, 3},
// expectedMask: []float32{
// 0, 0, x,
// x, 0, 0,
// },
// },
// {
// name: "First seq 1",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{0, 1},
// expected: []float32{5, 6},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 1",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{2, 3},
// expected: []float32{6, 3, 4, 7, 8},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// x, x, x, 0, 0,
// },
// },
// {
// name: "Third seq 0",
// in: []float32{9, 10},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{9, 10, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWAMemCache(1, 3, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 2, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestChunkedAttention(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewChunkedAttentionCache(2, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// testCache(
// t, backend, cache,
// []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6, 7},
// inShape: []int{1, 1, 3},
// seqs: []int{0, 0, 0},
// pos: []int32{4, 5, 6},
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
// expectedShape: []int{1, 1, 7},
// expectedMask: []float32{
// x, x, x, x, 0, x, x,
// x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, 0,
// },
// },
// {
// name: "ThirdBatch",
// in: []float32{8, 9},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{7, 8},
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// expectedShape: []int{1, 1, 9},
// expectedMask: []float32{
// x, x, x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, x, x, 0,
// },
// },
// },
// )
// })
// }
// func TestSequences(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{2, 2},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestRemove(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// return key.Add(ctx, shift), nil
// })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err := cache.Remove(0, 1, math.MaxInt32)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveEnd",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{1, 2},
// expected: []float32{1, 5, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, 0, x, x, x,
// x, x, 0, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err = cache.Remove(0, 0, 1)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveMiddle",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{1, 2},
// expected: []float32{7, 4, 3, 4, 6, 8},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{
// 0, 0, x, x, x, x,
// 0, 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestCopy(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// cache.CopyPrefix(0, 1, 2)
// tests = []testCase{
// {
// name: "Copy",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{3, 4},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// if err != nil {
// panic(err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats(test.in, test.inShape...)
// cache.Put(context, tensor, tensor)
// out, _, mask := cache.Get(context)
// context.Forward(out, mask).Compute(out, mask)
// if !slices.Equal(out.Floats(), test.expected) {
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// }
// if !slices.Equal(out.Shape(), test.expectedShape) {
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// }
// if !slices.Equal(mask.Floats(), test.expectedMask) {
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// }
// })
// }
// }
// func TestCanResume(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// cache := NewSWACache(windowSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4},
// Sequences: []int{0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// cache.Put(context, tensor, tensor)
// // with window size 4, nothing has slid out of the window yet
// if !cache.CanResume(0, 0) {
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
// }
// if !cache.CanResume(0, 1) {
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
// }
// if !cache.CanResume(0, 2) {
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
// }
// if !cache.CanResume(0, 3) {
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// }
// if !cache.CanResume(0, 4) {
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// }
// // shift window by adding position 5
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{5},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// }
// })
// }
// func TestCanResumeSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// memSize := int32(5)
// cache := NewSWAMemCache(windowSize, memSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// cache.Put(context, tensor, tensor)
// // shift window by adding position 7
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{7},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 6) {
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// }
// if !cache.CanResume(0, 7) {
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// }
// })
// }
// type testBackend struct {
// ml.Backend
// permutedV bool
// }
// func (b *testBackend) NewContext() ml.Context {
// return &testContext{}
// }
// func (b *testBackend) NewContextSize(int) ml.Context {
// return &testContext{}
// }
// func (b *testBackend) CacheConfig() ml.CacheConfig {
// return ml.CacheConfig{PermutedV: b.permutedV}
// }
// type testContext struct {
// ml.Context
// }
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// total := 0
// if len(shape) > 0 {
// total = 1
// for _, s := range shape {
// total *= s
// }
// }
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// return c.Empty(dtype, shape...)
// }
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
// copy(t.data, s)
// return t
// }
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// f := make([]float32, len(s))
// for i := range f {
// f[i] = float32(s[i])
// }
// out := c.FromFloats(f, shape...)
// out.(*testTensor).dtype = ml.DTypeI32
// return out
// }
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// s := make([]float32, 0, int((stop-start)/step))
// for i := start; i < stop; i += step {
// s = append(s, i)
// }
// out := c.FromFloats(s, len(s))
// out.(*testTensor).dtype = dtype
// return out
// }
// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
// func (c *testContext) Compute(...ml.Tensor) {}
// func (c *testContext) Reserve() {}
// func (c *testContext) MaxGraphNodes() int {
// return 10
// }
// func (c *testContext) Close() {}
// type testTensor struct {
// ml.Tensor
// dtype ml.DType
// elementSize int
// data []float32
// shape []int
// }
// func (t *testTensor) Dim(n int) int {
// return t.shape[n]
// }
// func (t *testTensor) Stride(n int) int {
// stride := t.elementSize
// for i := range n {
// stride *= t.shape[i]
// }
// return stride
// }
// func (t *testTensor) Shape() []int {
// return t.shape
// }
// func (t *testTensor) DType() ml.DType {
// return t.dtype
// }
// func (t *testTensor) Floats() []float32 {
// out := make([]float32, len(t.data))
// copy(out, t.data)
// return out
// }
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = -t.data[i]
// }
// return out
// }
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// }
// return out
// }
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: t.data,
// shape: shape,
// }
// }
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// offset /= t.elementSize
// var s []int
// switch len(shape) {
// case 1:
// s = []int{shape[0]}
// case 3:
// s = []int{shape[0], shape[2]}
// case 5:
// s = []int{shape[0], shape[2], shape[4]}
// default:
// panic("unsupported number of dimensions")
// }
// context := &testContext{}
// view := context.Empty(t.dtype, s...).(*testTensor)
// view.data = t.data[offset : offset+len(view.data)]
// return view
// }
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// if len(t.shape) > 4 || len(order) > 4 {
// panic("permute only supports up to 4 dimensions")
// }
// if len(order) != len(t.shape) && len(order) != 4 {
// panic("invalid number of dimensions for permute")
// }
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
// orderFull := append(make([]int, 0, 4), order...)
// for len(orderFull) < 4 {
// orderFull = append(orderFull, len(orderFull))
// }
// seen := [4]bool{}
// shape4 := [4]int{1, 1, 1, 1}
// for i := 0; i < len(t.shape) && i < 4; i++ {
// shape4[i] = t.shape[i]
// }
// newShape4 := [4]int{1, 1, 1, 1}
// for axis := range 4 {
// dst := orderFull[axis]
// if dst < 0 || dst >= 4 {
// panic("invalid axis for permute")
// }
// if seen[dst] {
// panic("duplicate axis for permute")
// }
// seen[dst] = true
// newShape4[dst] = shape4[axis]
// }
// total := len(t.data)
// newData := make([]float32, total)
// if total > 0 {
// oldDims := shape4
// newDims := newShape4
// oldStride := [4]int{1, 1, 1, 1}
// newStride := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
// newStride[i] = newStride[i-1] * newDims[i-1]
// }
// var coords [4]int
// var newCoords [4]int
// for idx := range total {
// remainder := idx
// for axis := range 4 {
// dim := oldDims[axis]
// if dim == 0 {
// coords[axis] = 0
// continue
// }
// coords[axis] = remainder % dim
// remainder /= dim
// }
// for axis := range 4 {
// newCoords[orderFull[axis]] = coords[axis]
// }
// newIndex := 0
// for axis := range 4 {
// if newDims[axis] == 0 {
// continue
// }
// newIndex += newCoords[axis] * newStride[axis]
// }
// newData[newIndex] = t.data[idx]
// }
// }
// numDims := 4
// for numDims > 1 && newShape4[numDims-1] <= 1 {
// numDims--
// }
// newShape := make([]int, numDims)
// copy(newShape, newShape4[:numDims])
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: newData,
// shape: newShape,
// }
// }
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// dst := t
// srcTensor := src.(*testTensor)
// idxTensor := idxs.(*testTensor)
// shapeTo4D := func(shape []int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 0; i < len(shape) && i < 4; i++ {
// out[i] = shape[i]
// }
// return out
// }
// computeStrides := func(shape [4]int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// out[i] = out[i-1] * shape[i-1]
// }
// return out
// }
// dstShape4D := shapeTo4D(dst.shape)
// srcShape4D := shapeTo4D(srcTensor.shape)
// idxShape4D := shapeTo4D(idxTensor.shape)
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// panic("SetRows requires matching tensor shapes")
// }
// if srcShape4D[1] != idxShape4D[0] {
// panic("SetRows rows/index mismatch")
// }
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// panic("SetRows cannot broadcast indices")
// }
// if idxShape4D[3] != 1 {
// panic("SetRows expects 1D or 2D index tensors")
// }
// dstStride := computeStrides(dstShape4D)
// srcStride := computeStrides(srcShape4D)
// idxStride := computeStrides(idxShape4D)
// numColumns := srcShape4D[0]
// numRows := srcShape4D[1]
// for dim3Index := range dstShape4D[3] {
// for dim2Index := range dstShape4D[2] {
// idxDim2 := 0
// idxDim3 := 0
// if idxShape4D[1] > 0 {
// idxDim2 = dim2Index % idxShape4D[1]
// }
// if idxShape4D[2] > 0 {
// idxDim3 = dim3Index % idxShape4D[2]
// }
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
// for row := range numRows {
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// if idx < 0 || idx >= dstShape4D[1] {
// panic("SetRows index out of range")
// }
// srcOffset := srcBase + row*srcStride[1]
// dstOffset := dstBase + idx*dstStride[1]
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// }
// }
// }
// return dst
// }
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// copy(t2.(*testTensor).data, t.data)
// return nil
// }
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type MLXCausal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewMLXCausalCache() *MLXCausal {
return &MLXCausal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
func (c *MLXCausal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *MLXCausal) Close() {
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX MLXCausal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }
package ml
import (
"fmt"
"log/slog"
"os"
"github.com/ollama/ollama/fs"
)
type Backend interface {
// Close frees all memory associated with this backend
// Close()
// Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
// BackendMemory() BackendMemory
Config() fs.Config
Get(name string) Tensor
NewContext() Context
// NewContextSize(size int) Context
// Enumerate the devices available for inference via this backend
// BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// AllocMemory causes the backend to allocate memory for the model. If
// false, this is only being used for discovering the required amount of
// memory and cannot load the model for running.
AllocMemory bool
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
// GPULayers is the set of layers to offload to GPUs
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
be := os.Getenv("OLLAMA_BACKEND")
if be == "" {
be = "mlx"
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
}
slog.Info("Loading new engine", "backend", be)
if backend, ok := backends[be]; ok {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
FromInts(s []int32, shape ...int) Tensor
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
// SetBatchSize(int)
Compute(...Tensor)
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available for
// for future inference.
// Reserve()
// MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// Load a tensor from "filename" safetensors file, and compare with the input tensor
// Returns error if the shape is inconsistent, or similarity measures are below 99%
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
type RoPEOptions struct {
Base *float32
Freqs Tensor
}
func WithRoPEBase(base float32) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Base = &base
}
}
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Freqs = freqs
}
}
type Tensor interface {
ToString() string
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
Dim(n int) int
Stride(n int) int
Shape() []int
DType() DType
// Cast(ctx Context, dtype DType) Tensor
// Bytes() []byte
Floats() []float32
Ints() []int32
// FromBytes([]byte)
// FromFloats([]float32)
// FromInts([]int32)
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
// Mul(ctx Context, t2 Tensor) Tensor
// Div(ctx Context, t2 Tensor) Tensor
Max(ctx Context, axes []int, keepDims bool) Tensor
Min(ctx Context, axes []int, keepDims bool) Tensor
Matmul(ctx Context, a2 Tensor) Tensor
// Mulmat(ctx Context, t2 Tensor) Tensor
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
// MulmatID(ctx Context, t2, ids Tensor) Tensor
// AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
// SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
// Sin(ctx Context) Tensor
// Cos(ctx Context) Tensor
// Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
// QuickGELU(ctx Context, up ...Tensor) Tensor
// SILU(ctx Context, up ...Tensor) Tensor
// RELU(ctx Context, up ...Tensor) Tensor
// Sigmoid(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
Transpose(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, allowColMajor bool) Tensor
// Pad(ctx Context, shape ...int) Tensor
// Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
// Repeat(ctx Context, dim, n int) Tensor
// Concat(ctx Context, t2 Tensor, dim int) Tensor
// Rows(ctx Context, t2 Tensor) Tensor
// TODO these probably aren't actually needed - false starts on trying to wire up cache
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
// Duplicate(ctx Context) Tensor
// Slice(ctx Context, dim, low, high, step int) Tensor
// Chunk(ctx Context, dim int, size int) []Tensor
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
// TopK(ctx Context, k int) Tensor
// Argsort(ctx Context) Tensor
// Mean(ctx Context) Tensor
// Variance(ctx Context) Tensor
// Stddev(ctx Context) Tensor
// Sqr(ctx Context) Tensor
// Sqrt(ctx Context) Tensor
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
// type number interface {
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// ~float32 | ~float64 |
// ~complex64 | ~complex128
// }
// func mul[T number](s ...T) T {
// p := T(1)
// for _, v := range s {
// p *= v
// }
// return p
// }
// type DumpOptions func(*dumpOptions)
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Precision = n
// }
// }
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Threshold = n
// }
// }
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.EdgeItems = n
// }
// }
// type dumpOptions struct {
// Precision, Threshold, EdgeItems int
// }
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// for _, optsFunc := range optsFuncs {
// optsFunc(&opts)
// }
// if mul(t.Shape()...) <= opts.Threshold {
// opts.EdgeItems = math.MaxInt
// }
// switch t.DType() {
// case DTypeFloat32:
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeFloat16: // TODO other types...
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// f32 = t.Copy(ctx, f32)
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeInt32:
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// return strconv.FormatInt(int64(i), 10)
// })
// default:
// return "<unsupported>"
// }
// }
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// if t.Bytes() == nil {
// ctx.Compute(t)
// }
// s := make(S, mul(t.Shape()...))
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// panic(err)
// }
// shape := t.Shape()
// slices.Reverse(shape)
// var sb strings.Builder
// var f func([]int, int)
// f = func(dims []int, stride int) {
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// sb.WriteString("[")
// defer func() { sb.WriteString("]") }()
// for i := 0; i < dims[0]; i++ {
// if i >= items && i < dims[0]-items {
// sb.WriteString("..., ")
// // skip to next printable element
// skip := dims[0] - 2*items
// if len(dims) > 1 {
// stride += mul(append(dims[1:], skip)...)
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// }
// i += skip - 1
// } else if len(dims) > 1 {
// f(dims[1:], stride)
// stride += mul(dims[1:]...)
// if i < dims[0]-1 {
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// }
// } else {
// text := fn(s[stride+i])
// if len(text) > 0 && text[0] != '-' {
// sb.WriteString(" ")
// }
// sb.WriteString(text)
// if i < dims[0]-1 {
// sb.WriteString(", ")
// }
// }
// }
// }
// f(shape, 0)
// return sb.String()
// }
type DType int
const (
DTypeBool DType = iota
DTypeUint8
DTypeUint16
DTypeUint32
DTypeUint64
DTypeInt8
DTypeInt16
DTypeInt32
DTypeInt64
DTypeFloat16
DTypeFloat32
DTypeFloat64
DTypeBfloat16
DTypeComplex64
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)
package backend
// _ "github.com/ollama/ollama/x/ml/backend/mlx"
include(FetchContent)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
)
endif()
endfunction()
# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
else()
# On Linux, disable Metal backend
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG v0.4.1)
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)
//go:build mlx
package mlx
/*
#cgo CPPFLAGS: -I${SRCDIR}/../../../../build/_deps/mlx-c-src
#cgo LDFLAGS: -L${SRCDIR}/../../../../build/lib/ollama/ -lmlxc -lmlx
#cgo LDFLAGS: -framework Accelerate
#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../../build/lib/ollama/
#include <stdlib.h>
#include "mlx/c/mlx.h"
static inline size_t stride(const mlx_array a, int i) {return mlx_array_strides(a)[i];}
extern void goStackTrace();
static void error_handler(const char *msg, void* data) {
fprintf(stderr, "MLX error: %s\n", msg);
goStackTrace();
exit(-1); // TODO adjust so this can become a return code on the current thread instead of exit
}
static void set_error_handler() {mlx_set_error_handler(&error_handler, NULL, NULL);}
static void* mlx_array_data_float16_asvoid(const mlx_array a) {return (void*)mlx_array_data_float16(a);}
typedef const char cchar_t;
*/
import "C"
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"reflect"
"runtime"
"runtime/debug"
"sync"
"unsafe"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/x448/float16"
)
func init() {
ml.RegisterBackend("mlx", New)
C.set_error_handler()
}
//export goStackTrace
func goStackTrace() {
debug.PrintStack()
}
type SafetensorsIndexMetadata struct {
TotalSize uint64 `json:"total_size"`
}
type SafetensorsIndex struct {
Metadata SafetensorsIndexMetadata `json:"metadata"`
WeightMap map[string]string `json:"weight_map"`
}
type Backend struct {
meta fs.Config
tensors map[string]*Array
}
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
// TODO assumes modelPath is actually a directory for now...
kv, tokenizer, err := convert.LoadModelMetadata(os.DirFS(modelPath))
if err != nil {
return nil, fmt.Errorf("unable to load model: %w", err)
}
b := &Backend{
meta: kv.KV(tokenizer),
}
err = b.LoadSafeTensors(modelPath)
if err != nil {
return nil, fmt.Errorf("safetensors load failed: %w", err)
}
return b, nil
}
func (b *Backend) LoadSafeTensors(dir string) error {
if _, err := os.Stat(dir); err != nil {
return fmt.Errorf("failed to stat dir: %w", err)
}
// other variations to try?
stFilename := filepath.Join(dir, "model.safetensors.index.json")
if _, err := os.Stat(stFilename); err != nil {
return fmt.Errorf("failed to stat %s: %w", stFilename, err)
}
fp, err := os.Open(stFilename)
if err != nil {
return fmt.Errorf("failed to open safetensor index: %s: %w", stFilename, err)
}
decoder := json.NewDecoder(fp)
var index SafetensorsIndex
if err := decoder.Decode(&index); err != nil {
return fmt.Errorf("decode error: %s: %w", stFilename, err)
}
slog.Info("XXX parsed metadata", "size", index.Metadata.TotalSize, "weights", len(index.WeightMap))
filenames := map[string]struct{}{}
for _, filename := range index.WeightMap {
filenames[filename] = struct{}{}
}
stream := C.mlx_default_cpu_stream_new()
b.tensors = map[string]*Array{}
for filename := range filenames {
filepath := filepath.Join(dir, filename)
if _, err := os.Stat(filepath); err != nil {
return fmt.Errorf("failed to stat %s: %w", filepath, err)
}
slog.Info("Loading tensors from", "filename", filename)
cFilename := C.CString(filepath)
defer C.free(unsafe.Pointer(cFilename))
data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
metadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_array_free(data)
defer C.mlx_map_string_to_string_free(metadata)
if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
// TODO with the current error handling, this will never happen
return fmt.Errorf("load failed")
}
it := C.mlx_map_string_to_array_iterator_new(data)
// defer C.mlx_array_free(shaped)
// TODO confusing how memory management works with this...
for {
var key *C.cchar_t
var value C.mlx_array
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
k := C.GoString((*C.char)(key))
b.tensors[k] = &Array{
name: k,
a: value,
}
// slog.Info("XXX read", "tensor", b.tensors[k], "type", b.tensors[k].TypeString())
}
}
return nil
}
func (b *Backend) Get(name string) ml.Tensor {
var t ml.Tensor
var ok bool
if t, ok = b.tensors[name]; !ok {
// slog.Warn("unable to locate", "tensor", name)
return nil
}
// slog.Info("Fetching", "tensor", name, "type", b.tensors[name].TypeString())
return t
}
func (b *Backend) NewContext() ml.Context {
// slog.Info("MLX.NewContext")
return &Context{
stream: C.mlx_default_gpu_stream_new(),
}
}
func (b *Backend) Config() fs.Config {
return b.meta
}
type Context struct {
stream C.mlx_stream
mu sync.Mutex
arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering?
}
func (c *Context) Close() {
// C.mlx_synchronize(c.stream) // ???
C.mlx_stream_free(c.stream)
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.arrays {
slog.Info("XXX freeing", "array", a)
C.mlx_array_free(a)
}
}
func (c *Context) Compute(tensors ...ml.Tensor) {
// TODO - for the zero tensor case this feels like it might not be correct...
needSync := true
sync := func() {
if needSync {
C.mlx_synchronize(c.stream)
needSync = false
}
}
vec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vec)
for _, t := range tensors {
C.mlx_vector_array_append_value(vec, t.(*Array).a)
t.(*Array).sync = sync
}
C.mlx_async_eval(vec)
}
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
vec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vec)
needSync := true
sync := func() {
if needSync {
C.mlx_synchronize(c.stream)
needSync = false
}
}
for _, t := range tensors {
t.(*Array).sync = sync
C.mlx_vector_array_append_value(vec, t.(*Array).a)
}
C.mlx_async_eval(vec)
return c
}
func (c *Context) Input() ml.Context {
return c
}
// func (c *Context) Output() ml.Context {
// return c
// }
func (c *Context) Layer(_ int) ml.Context {
return c
}
func (c *Context) RandomNormal(shape []int, dtype ml.DType, loc, scale float32, key ml.Tensor) ml.Tensor {
var r C.mlx_array
var k C.mlx_array
if key != nil {
k = key.(*Array).a
}
sh := make([]C.int, len(shape))
for i := range shape {
sh[i] = C.int(shape[i])
}
C.mlx_random_normal(
&r,
&sh[0],
C.size_t(len(shape)),
C.mlx_dtype(dtype),
C.float(loc),
C.float(scale),
k,
c.stream,
)
return newArray(c, r)
}
func (c *Context) CompareWith(filepath string, tensors map[string]ml.Tensor, abortOnError bool) (err error) {
minCosine := float32(0.96) // TODO too low...
fileTensors := map[string]*Array{}
defer func() {
if err != nil {
for k, v := range tensors {
fmt.Fprintln(os.Stderr, "input tensor "+k+"\n"+v.ToString())
if fv, ok := fileTensors[k]; ok {
fmt.Fprintln(os.Stderr, " file tensor "+k+"\n"+fv.ToString())
} else {
fmt.Fprintln(os.Stderr, " file tensor "+k+" missing!\n")
}
}
}
if abortOnError {
if err != nil {
panic(fmt.Sprintf("%s", err))
}
}
}()
if _, err = os.Stat(filepath); err != nil {
filepath += ".safetensors"
if _, err = os.Stat(filepath); err != nil {
err = fmt.Errorf("failed to stat %s: %w", filepath, err)
return
}
err = nil
}
// slog.Info("Loading tensors from", "filename", filepath)
cFilename := C.CString(filepath)
defer C.free(unsafe.Pointer(cFilename))
data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
metadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_array_free(data)
defer C.mlx_map_string_to_string_free(metadata)
stream := C.mlx_default_cpu_stream_new()
if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
// TODO with the current error handling, this will never happen
err = fmt.Errorf("load failed")
return
}
it := C.mlx_map_string_to_array_iterator_new(data)
allTensors := []ml.Tensor{}
for _, t := range tensors {
allTensors = append(allTensors, t)
}
for {
var key *C.cchar_t
var value C.mlx_array
defer C.mlx_array_free(value)
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
k := C.GoString((*C.char)(key))
var r C.mlx_array
defer C.mlx_array_free(r)
C.mlx_astype(
&r,
value,
C.MLX_FLOAT32,
stream,
)
fileTensors[k] = &Array{
name: k,
a: r,
}
// slog.Info("XXX read", "tensor", t, "type", t.TypeString())
allTensors = append(allTensors, fileTensors[k])
}
c.Forward(allTensors...)
for k, t := range tensors {
a, ok := fileTensors[k]
if !ok {
err = fmt.Errorf("tensor named %s not found in file", k)
return
}
if !reflect.DeepEqual(a.Shape(), t.Shape()) {
err = fmt.Errorf("mismatched shapes: file: %v vs. input %v", a.Shape(), t.Shape())
return
}
// slog.Info("XXX shapes match", "shape", t.Shape())
// TODO handle int types...
tDType := t.DType()
if tDType != ml.DTypeFloat16 && tDType != ml.DTypeFloat32 {
var r C.mlx_array
defer C.mlx_array_free(r)
C.mlx_astype(
&r,
t.(*Array).a,
C.MLX_FLOAT32,
stream,
)
t = &Array{
a: r,
}
c.Forward(t)
}
af := a.Floats()
tf := t.Floats()
cos := cosineSimilarity(af, tf)
diff := a.Sub(c, t)
min := diff.Min(c, nil, true)
max := diff.Max(c, nil, true)
c.Forward(min, max)
minf := min.Floats()
maxf := max.Floats()
if cos < minCosine {
err = fmt.Errorf("%s shapes match, but not similar enough: %v min_difference=%v max_difference=%v", k, cos, minf, maxf)
return
}
slog.Info("XXX tensors are similar", k, cos, "shape", t.Shape(), "min_difference", minf, "max_difference", maxf)
}
err = nil
return
}
func dotProduct[V float32 | float64](v1, v2 []V) V {
var result V = 0
if len(v1) != len(v2) {
return result
}
for i := 0; i < len(v1); i++ {
result += v1[i] * v2[i]
}
return result
}
func magnitude[V float32 | float64](v []V) V {
var result V = 0
for _, val := range v {
result += val * val
}
return V(math.Sqrt(float64(result)))
}
func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
mag1 := magnitude(v1)
mag2 := magnitude(v2)
if mag1 == 0 || mag2 == 0 {
return 0
}
return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
}
func euclideanDistance[V float32 | float64](v1, v2 []V) V {
if len(v1) != len(v2) {
return V(math.Inf(1))
}
var sum V = 0
for i := 0; i < len(v1); i++ {
diff := v1[i] - v2[i]
sum += diff * diff
}
return V(math.Sqrt(float64(sum)))
}
func manhattanDistance[V float32 | float64](v1, v2 []V) V {
if len(v1) != len(v2) {
return V(math.Inf(1))
}
var sum V = 0
for i := 0; i < len(v1); i++ {
sum += V(math.Abs(float64(v1[i] - v2[i])))
}
return sum
}
type Array struct {
name string
a C.mlx_array
c *Context
sync func()
}
func newArray(ctx *Context, a C.mlx_array) *Array {
// TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time
var name string
_, f, l, ok := runtime.Caller(2)
if ok {
name = fmt.Sprintf("%s:%d", f, l)
}
t := &Array{
name: name,
a: a,
c: ctx,
}
// DEBUG memory allocation problems...
// slog.Info("XXX Allocated", "array", t, "a", a)
ctx.mu.Lock()
defer ctx.mu.Unlock()
ctx.arrays = append(ctx.arrays, a)
return t
}
// FromFloats implements ml.Context.
func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor {
u16s := make([]float16.Float16, len(s))
for i := range u16s {
u16s[i] = float16.Fromfloat32(s[i])
}
cshape := make([]C.int, len(shape))
for i, dim := range shape {
cshape[i] = C.int(dim)
}
return newArray(c,
C.mlx_array_new_data(
unsafe.Pointer(&u16s[0]),
&cshape[0],
C.int(len(cshape)),
C.MLX_FLOAT16,
),
)
}
func (a *Array) Floats() []float32 {
if a.sync != nil {
a.sync()
}
l := (int)(C.mlx_array_size(a.a))
switch C.mlx_array_dtype(a.a) {
case C.MLX_BFLOAT16:
panic("bfloat16 not yet implemented")
case C.MLX_FLOAT16:
data := C.mlx_array_data_float16_asvoid(a.a)
if data == nil {
panic("nil data, wasn't eval'd")
}
u16s := unsafe.Slice((*uint16)(data), l)
f32s := make([]float32, len(u16s))
for i := range u16s {
f32s[i] = float16.Frombits(u16s[i]).Float32()
}
return f32s
case C.MLX_FLOAT32:
data := C.mlx_array_data_float32(a.a)
if data == nil {
panic("nil data, wasn't eval'd")
}
f32s := unsafe.Slice((*float32)(data), l)
return f32s
default:
panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a)))
}
}
// FromInts implements ml.Context.
func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor {
cshape := make([]C.int, len(shape))
for i, dim := range shape {
cshape[i] = C.int(dim)
}
return newArray(c,
C.mlx_array_new_data(
unsafe.Pointer(&s[0]),
&cshape[0],
C.int(len(cshape)),
C.MLX_INT32,
),
)
}
func (a *Array) Ints() []int32 {
if a.sync != nil {
a.sync()
}
l := (int)(C.mlx_array_size(a.a))
switch C.mlx_array_dtype(a.a) {
case C.MLX_INT32:
data := C.mlx_array_data_int32(a.a)
if data == nil {
panic("nil data, wasn't eval'd")
}
i32s := unsafe.Slice((*int32)(data), l)
return i32s
// TODO other types via conversion?
default:
panic(fmt.Sprintf("unsupported dtype for Ints: %d", C.mlx_array_dtype(a.a)))
}
}
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
sh := make([]C.int, len(shape))
for i, s := range shape {
sh[i] = (C.int)(s)
}
var r C.mlx_array
C.mlx_zeros(
&r,
&sh[0],
(C.size_t)(len(sh)),
C.mlx_dtype(dtype),
c.stream,
)
return newArray(c, r)
}
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// TODO more efficient impl?
return c.Zeros(dtype, shape...)
}
func (a *Array) DType() ml.DType {
return (ml.DType)(C.mlx_array_dtype(a.a))
}
func (a *Array) Dim(n int) int {
return int(C.mlx_array_dim(a.a, C.int(n)))
}
func (a *Array) Stride(n int) int {
return (int)(C.stride(a.a, (C.int)(n)))
}
func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
var r C.mlx_array
C.mlx_arange(
&r,
C.double(start),
C.double(stop),
C.double(step),
(C.mlx_dtype)(dtype),
c.stream,
)
return newArray(c, r)
}
// Scale implements ml.Tensor.
func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
scale := C.mlx_array_new_float(C.float(s))
var r C.mlx_array
C.mlx_multiply(
&r,
a.a,
scale,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
var r C.mlx_array
C.mlx_softmax(
&r,
a.a,
false, // TODO - precise?
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) SliceUpdate(ctx ml.Context, update ml.Tensor, start, stop, strides []int) ml.Tensor {
cStart := make([]C.int, len(start))
for i := range start {
cStart[i] = C.int(start[i])
}
cStop := make([]C.int, len(stop))
for i := range stop {
cStop[i] = C.int(stop[i])
}
cStrides := make([]C.int, len(strides))
for i := range strides {
cStrides[i] = C.int(strides[i])
}
var r C.mlx_array
C.mlx_slice_update(
&r,
a.a,
update.(*Array).a,
(*C.int)(unsafe.Pointer(&cStart[0])),
C.size_t(len(cStart)),
(*C.int)(unsafe.Pointer(&cStop[0])),
C.size_t(len(cStop)),
(*C.int)(unsafe.Pointer(&cStrides[0])),
C.size_t(len(cStrides)),
ctx.(*Context).stream,
)
// Release the old array and replace with the new one to ensure the same underlying buffer is used
a.c.mu.Lock()
defer a.c.mu.Unlock()
for i := range a.c.arrays {
if a.c.arrays[i] == a.a {
C.mlx_array_free(a.a)
a.a = r
a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
return a
}
}
panic("unable to locate array in context")
}
func (a *Array) SliceUpdateDynamic(ctx ml.Context, update, start ml.Tensor, axes []int) ml.Tensor {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
var r C.mlx_array
C.mlx_slice_update_dynamic(
&r,
a.a,
update.(*Array).a,
start.(*Array).a,
(*C.int)(unsafe.Pointer(&cAxes[0])),
C.size_t(len(cAxes)),
ctx.(*Context).stream,
)
// Release the old array and replace with the new one to ensure the same underlying buffer is used
a.c.mu.Lock()
defer a.c.mu.Unlock()
for i := range a.c.arrays {
if a.c.arrays[i] == a.a {
C.mlx_array_free(a.a)
a.a = r
a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
return a
}
}
panic("unable to locate array in context")
}
func (a *Array) PutAlongAxis(ctx ml.Context, indicies, values ml.Tensor, axis int) ml.Tensor {
var r C.mlx_array
C.mlx_put_along_axis(
&r,
a.a,
indicies.(*Array).a,
values.(*Array).a,
C.int(axis),
ctx.(*Context).stream,
)
// Release the old array and replace with the new one to ensure the same underlying buffer is used
a.c.mu.Lock()
defer a.c.mu.Unlock()
for i := range a.c.arrays {
if a.c.arrays[i] == a.a {
C.mlx_array_free(a.a)
a.a = r
a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
return a
}
}
panic("unable to locate array in context")
}
func (a *Array) Scatter(ctx ml.Context, indicies []ml.Tensor, updates ml.Tensor, axes []int) ml.Tensor {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
var cAxes0 *C.int
if len(cAxes) > 0 {
cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
}
indiciesVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(indiciesVec)
for _, ind := range indicies {
C.mlx_vector_array_append_value(indiciesVec, ind.(*Array).a)
}
var r C.mlx_array
C.mlx_scatter(
&r,
a.a,
indiciesVec,
updates.(*Array).a,
cAxes0,
C.size_t(len(cAxes)),
ctx.(*Context).stream,
)
// Release the old array and replace with the new one to ensure the same underlying buffer is used
a.c.mu.Lock()
defer a.c.mu.Unlock()
for i := range a.c.arrays {
if a.c.arrays[i] == a.a {
C.mlx_array_free(a.a)
a.a = r
a.c.arrays[i] = r
return a
}
}
panic("unable to locate array in context")
}
func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
C.mlx_copy(
&a2.(*Array).a,
a.a,
ctx.(*Context).stream,
)
// TODO - view?
return newArray(ctx.(*Context), a2.(*Array).a)
}
func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
var r C.mlx_array
C.mlx_add(
&r,
a.a,
a2.(*Array).a,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) Sub(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
var r C.mlx_array
C.mlx_subtract(
&r,
a.a,
a2.(*Array).a,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) Max(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
var r C.mlx_array
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
var cAxes0 *C.int
if len(cAxes) > 0 {
cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
C.mlx_max_axes(
&r,
a.a,
cAxes0,
C.size_t(len(cAxes)),
C._Bool(keepDims),
ctx.(*Context).stream,
)
} else {
C.mlx_max(
&r,
a.a,
C._Bool(keepDims),
ctx.(*Context).stream,
)
}
return newArray(ctx.(*Context), r)
}
func (a *Array) Min(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
var r C.mlx_array
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
var cAxes0 *C.int
if len(cAxes) > 0 {
cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
C.mlx_min_axes(
&r,
a.a,
cAxes0,
C.size_t(len(cAxes)),
C._Bool(keepDims),
ctx.(*Context).stream,
)
} else {
C.mlx_min(
&r,
a.a,
C._Bool(keepDims),
ctx.(*Context).stream,
)
}
return newArray(ctx.(*Context), r)
}
func (a *Array) Matmul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
var r C.mlx_array
C.mlx_matmul(
&r,
a.a,
a2.(*Array).a,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
// slog.Info("MLX.RMSNorm", "a", a, "w", w)
var r C.mlx_array
C.mlx_fast_rms_norm(
&r,
a.a,
w.(*Array).a,
C.float(eps),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
var r C.mlx_array
C.mlx_fast_layer_norm(
&r,
a.a,
w.(*Array).a,
b.(*Array).a,
C.float(eps),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
// TODO implement
panic("NOT YET IMPLEMENTED")
}
func (t Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("NOT YET IMPLEMENTED")
}
// RoPE implements Rotary Positional Encoding
//
// dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
// traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions.
// scale (float) – The scale used to scale the positions.
// offset (int) – The position offset to start at. TODO MLX-C does not yet expose Offset as an Array
// WithBase (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None.
// WithFreqs (array, optional) – Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None.
func (a *Array) RoPE(ctx ml.Context, dims int, traditional bool, scale float32, offset int, options ...func(*ml.RoPEOptions)) ml.Tensor {
opts := ml.RoPEOptions{}
// Apply any provided options
for _, option := range options {
option(&opts)
}
var r C.mlx_array
var base C.mlx_optional_float
var freqs C.mlx_array
if opts.Base != nil {
base.value = C.float(*opts.Base)
base.has_value = true
}
if opts.Freqs != nil {
freqs = opts.Freqs.(*Array).a
}
C.mlx_fast_rope(
&r,
a.a,
C.int(dims),
C._Bool(traditional),
base,
C.float(scale),
C.int(offset),
freqs,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
// A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
//
// Supports:
// - Multi-Head Attention
// - Grouped Query Attention
// - Multi-Query Attention
//
// Note:
// - The softmax operation is performed in float32 regardless of the input precision.
// - For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
//
// In the following the dimensions are given by:
// - B: The batch size.
// - N_q: The number of query heads.
// - N_kv: The number of key and value heads.
// - T_q: The number of queries per example.
// - T_kv: The number of keys and values per example.
// - D: The per-head dimension.
//
// Parameters:
// - [subject array] queries (array) – Queries with shape [B, N_q, T_q, D].
// - keys (array) – with shape [B, N_kv, T_kv, D].
// - values (array) – with shape [B, N_kv, T_kv, D].
// - scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape(-1)).
// - mask (str or array, optional) – The mask to apply to the query-key scores.
// The mask can be an array or a string indicating the mask type. The only supported string type is "causal".
// If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and
// must be broadcast-compatible with the shape [B, N, T_q, T_kv]. If an additive mask is given its type must
// promote to the promoted type of q, k, and v.
// - sinks (array, optional) – An optional array of attention sinks. Default: None.
func (queries *Array) ScaledDotProductAttention(ctx ml.Context, keys, values ml.Tensor, scale float64, maskMode string, mask ml.Tensor, sinks ml.Tensor) ml.Tensor {
var r C.mlx_array
var s C.mlx_array
if sinks != nil {
s = sinks.(*Array).a
}
maskModeC := C.CString(maskMode)
defer C.free(unsafe.Pointer(maskModeC))
var maskArr C.mlx_array
if mask != nil {
maskArr = mask.(*Array).a
}
C.mlx_fast_scaled_dot_product_attention(
&r,
queries.a,
keys.(*Array).a,
values.(*Array).a,
C.float(scale),
maskModeC,
maskArr,
s,
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) TakeAxes(ctx ml.Context, indicies ml.Tensor, axes int) ml.Tensor {
var r C.mlx_array
C.mlx_take_axis(&r, a.a, indicies.(*Array).a, C.int(axes), ctx.(*Context).stream)
return newArray(ctx.(*Context), r)
}
// TODO not sure if we'll want this variation taking raw ints instead of a tensor...
// func (a *Array) TakeAxes(ctx ml.Context, axes int, indicies ...int) ml.Tensor {
// var i C.mlx_array
// var r C.mlx_array
// if indicies != nil {
// shape := []C.int{C.int(len(indicies))}
// cindicies := make([]int32, len(indicies))
// for i, v := range indicies {
// cindicies[i] = int32(v)
// }
// i = C.mlx_array_new_data(
// unsafe.Pointer(&cindicies[0]),
// &shape[0],
// C.int(len(shape)),
// C.MLX_INT32,
// )
// }
// C.mlx_take_axis(&r, a.a, i, C.int(axes), ctx.(*Context).stream)
// return newArray(ctx.(*Context), r)
// }
func (a *Array) GELU(ctx ml.Context, up ...ml.Tensor) ml.Tensor {
// TODO precise vs fast, and compile
// x * mx.sigmoid(1.702 * x)
u16s := []float16.Float16{float16.Fromfloat32(1.702)}
cshape := []C.int{1}
f := C.mlx_array_new_data(unsafe.Pointer(&u16s[0]), &cshape[0], 1, C.MLX_FLOAT16)
defer C.mlx_array_free(f)
var r1, r2, r3 C.mlx_array
C.mlx_multiply(&r1, a.a, f, ctx.(*Context).stream)
defer C.mlx_array_free(r1)
C.mlx_sigmoid(&r2, r1, ctx.(*Context).stream)
defer C.mlx_array_free(r2)
C.mlx_multiply(&r3, a.a, r2, ctx.(*Context).stream)
if len(up) > 0 {
var r4 C.mlx_array
defer C.mlx_array_free(r3)
C.mlx_multiply(&r4, r3, up[0].(*Array).a, ctx.(*Context).stream)
return newArray(ctx.(*Context), r4)
}
return newArray(ctx.(*Context), r3)
}
// Create a view into the array with the given shape and strides.
//
// The resulting array will always be as if the provided array was row
// contiguous regardless of the provided arrays storage order and current
// strides.
//
// Note that this function should be used with caution as it changes the shape
// and strides of the array directly. This can lead to the resulting array
// pointing to invalid memory locations which can result into crashes.
//
// Parameters:
// - shape (list(int), optional) – The shape of the resulting array. If None it defaults to a.shape().
// - strides (list(int), optional) – The strides of the resulting array. If None it defaults to the
// reverse exclusive cumulative product of a.shape().
// - offset (int) – Skip that many elements from the beginning of the input array.
func (a *Array) AsStrided(ctx ml.Context, shape, strides []int, offset int) ml.Tensor {
var r C.mlx_array
sh := make([]C.int, len(shape))
st := make([]C.int64_t, len(strides))
var sh0 *C.int
var st0 *C.int64_t
for i, s := range shape {
sh[i] = C.int(s)
}
for i, s := range strides {
st[i] = C.int64_t(s)
}
if len(sh) > 0 {
sh0 = (*C.int)(unsafe.Pointer(&sh[0]))
}
if len(st) > 0 {
st0 = (*C.int64_t)(unsafe.Pointer(&st[0]))
}
C.mlx_as_strided(
&r,
a.a,
sh0,
C.size_t(len(sh)),
st0,
C.size_t(len(st)),
C.size_t(offset),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
cshape := make([]C.int, len(shape))
for i, dim := range shape {
cshape[i] = C.int(dim)
}
var r C.mlx_array
C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream)
return newArray(ctx.(*Context), r)
}
func (a *Array) Transpose(ctx ml.Context, shape ...int) ml.Tensor {
ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape)))
var r C.mlx_array
sh := make([]C.int, ndim)
for i := range ndim {
sh[i] = (C.int)(shape[i])
if int(sh[i]) >= int(ndim) {
slog.Error("Permute error", "tensor", a, "shape", shape)
panic("invalid pemute call")
}
}
if len(sh) > 0 {
C.mlx_transpose_axes(
&r,
a.a,
&sh[0],
ndim,
ctx.(*Context).stream,
)
} else {
C.mlx_transpose(
&r,
a.a,
ctx.(*Context).stream,
)
}
return newArray(ctx.(*Context), r)
}
func (a *Array) Contiguous(ctx ml.Context, allowColMajor bool) ml.Tensor {
var r C.mlx_array
C.mlx_contiguous(
&r,
a.a,
(C._Bool)(allowColMajor),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
// Conv2D implements ml.Tensor.
// GGML API
// input: [N, IC, IH, IW]
// weight: [OC,IC, KH, KW]
// result: [N, OC, OH, OW]
//
// MLX:
// input: (N, KH, KW, C_in)
// weight: (C_out, IH, IW, C_in)
// result: XXX
func (input *Array) Conv2D(ctx ml.Context, weight ml.Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) ml.Tensor {
var r C.mlx_array
C.mlx_conv2d(
&r,
input.a,
weight.(*Array).a,
C.int(stride0),
C.int(stride1),
C.int(padding0),
C.int(padding1),
C.int(dilation0),
C.int(dilation1),
C.int(groups),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (input *Array) Conv3D(ctx ml.Context, weight ml.Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) ml.Tensor {
var r C.mlx_array
C.mlx_conv3d(
&r,
input.a,
weight.(*Array).a,
C.int(stride0),
C.int(stride1),
C.int(stride2),
C.int(padding0),
C.int(padding1),
C.int(padding2),
C.int(dilation0),
C.int(dilation1),
C.int(dilation2),
C.int(groups),
ctx.(*Context).stream,
)
return newArray(ctx.(*Context), r)
}
func (a *Array) ToString() string {
str := C.mlx_string_new()
C.mlx_array_tostring(&str, a.a)
s := C.mlx_string_data(str)
defer C.mlx_string_free(str)
return C.GoString(s)
}
func (a *Array) LogValue() slog.Value {
dims := int(C.mlx_array_ndim(a.a))
strides := make([]int, dims)
for i := range strides {
strides[i] = int(C.stride(a.a, (C.int)(i)))
}
return slog.GroupValue(
slog.String("name", a.name),
slog.String("type", a.TypeString()),
slog.Any("shape", a.Shape()),
slog.Any("strides", strides),
// slog.String("values", C.GoString(s)),
)
}
func (a *Array) Shape() []int {
shape := make([]int, C.mlx_array_ndim(a.a))
for i := range shape {
shape[i] = int(C.mlx_array_dim(a.a, C.int(i)))
}
return shape
}
func (a *Array) TypeString() string {
switch C.mlx_array_dtype(a.a) {
case C.MLX_BOOL:
return "bool"
case C.MLX_UINT8:
return "uint8"
case C.MLX_UINT16:
return "uint16"
case C.MLX_UINT32:
return "uint32"
case C.MLX_UINT64:
return "uint64"
case C.MLX_INT8:
return "int8"
case C.MLX_INT16:
return "int16"
case C.MLX_INT32:
return "int32"
case C.MLX_INT64:
return "int64"
case C.MLX_FLOAT16:
return "float16"
case C.MLX_FLOAT32:
return "float32"
case C.MLX_BFLOAT16:
return "bfloat16"
case C.MLX_COMPLEX64:
return "complex64"
default:
return "unknown"
}
}
//go:build mlx
package mlx
import (
"log/slog"
"os"
"reflect"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
_ "github.com/ollama/ollama/x/model/models/gemma3"
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestLoadModel(t *testing.T) {
dir := "/Users/daniel/Models/gemma-3-4b-it/"
b := &Backend{}
err := b.LoadSafeTensors(dir)
if err != nil {
t.Fatalf("load failed: %s", err)
}
}
func TestFromInts(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []int32{1, 2, 3, 4, 5, 6}
a := c.FromInts(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
}
func TestFromFloats(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []float32{1, 2, 3, 4, 5, 6}
a := c.FromFloats(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
res := a.Floats()
if !reflect.DeepEqual(res, data) {
t.Fatalf("incorrect results: %v", res)
}
}
func TestAdd(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
t3 := t1.Add(c, t2)
c.Compute(t3, exp)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp.Floats()) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestReshapeTranspose(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
c.Compute(t1)
t1f := t1.Floats()
exp := []float32{
0, 4, 8,
1, 5, 9,
2, 6, 10,
3, 7, 11,
12, 16, 20,
13, 17, 21,
14, 18, 22,
15, 19, 23,
}
if !reflect.DeepEqual(t1f, exp) {
t.Fatalf("incorrect results: %v", t1f)
}
}
func prod(vals ...int) int {
r := 1
for _, v := range vals {
r *= v
}
return r
}
func TestMatmul(t *testing.T) {
// TODO create scenarios...
b := &Backend{}
c := b.NewContext()
defer c.Close()
s1 := []int{1, 3, 2, 4}
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
s2 := []int{4, 2}
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
t3 := t1.Matmul(c, t2)
exp := []float32{
28, 34,
76, 98,
124, 162,
172, 226,
220, 290,
268, 354,
}
c.Compute(t3)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestRows(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
outputs := c.Zeros(ml.DTypeInt32, 1)
t2 := t1.TakeAxes(c, outputs, 1)
c.Forward(t1, t2).Compute(t1, t2)
t.Log(t1.ToString())
t.Log(t2.ToString())
f := t2.Floats()
t.Logf("Result: %v", f)
}
func TestCaching(t *testing.T) {
// Validate the caching algorithm
b := &Backend{}
c := b.NewContext()
defer c.Close()
batchSize := 3
headDim := 4
numKVHeads := 2
// Make cache twice the size of one test batch
cells := batchSize * 2
cellSize := numKVHeads * headDim
shape := []int{1, numKVHeads, batchSize, headDim}
stop := float32(1)
for _, x := range shape {
stop *= float32(x)
}
// Create the cache
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
// Input tensor
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
// Reshape to copy into the cache
/*
From MLX python/src/indexing.cpp mlx_scatter_args_array
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
*/
numRows := 3
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
// Simulate cells 1,3,5 are available
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
cache.Scatter(c, indicies, up, axis)
c.Forward(cache)
// Cache should contain the data now
t.Log("Cache after put\n" + cache.ToString())
// Retrieve cache content and verify it matches
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
t1f := t1.Floats()
outf := out.Floats()
if !reflect.DeepEqual(t1f, outf) {
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
}
}
func TestGemma3(t *testing.T) {
// Why is the sky blue
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
limit := 50
// TODO generalize this
dir := "/Users/daniel/Models/gemma-3-4b-it/"
m, err := model.New(dir, ml.BackendParams{})
if err != nil {
t.Fatalf("unable to load model: %s", err)
}
b := m.Backend()
ctx := b.NewContext()
defer ctx.Close()
batch := input.Batch{
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
Offset: 0,
}
for i := range len(inputs) {
batch.Positions[i] = int32(i)
}
offset := len(inputs)
cache := m.Config().Cache
if cache != nil {
numSlots := 1
batchSize := 512
numCtx := 4096
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
err := cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
}
opts := api.DefaultOptions()
var grammar *sample.GrammarSampler
sampler := sample.NewSampler(
opts.Temperature,
opts.TopK,
opts.TopP,
opts.MinP,
opts.Seed,
grammar,
)
t.Log("Starting Forward pass loop")
pendingResponses := []string{}
for {
out, err := m.Forward(ctx, batch)
if err != nil {
t.Fatalf("failed forward pass: %s", err)
}
ctx.Forward(out)
outputs := out.Floats()
t.Logf("finished forward pass! length:%d", len(outputs))
// sample a token
logits := outputs
token, err := sampler.Sample(logits)
if err != nil {
t.Fatalf("unable to sample token: %s", err)
}
t.Logf("Sampled token: %v", token)
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
t.Log("hit EOS")
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{token})
if err != nil {
t.Fatalf("unable to decode token: %s", err)
}
pendingResponses = append(pendingResponses, piece)
sequence := strings.Join(pendingResponses, "")
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
t.Logf("hit stop token: %v", stop)
break
}
t.Logf("RESULTS: %s", sequence)
batch = input.Batch{
Inputs: ctx.FromInts([]int32{token}, 1, 1),
Positions: make([]int32, 1),
Sequences: make([]int, 1),
Outputs: ctx.FromInts([]int32{0}, 1),
Offset: offset,
}
offset++
batch.Positions[0] = 0
err = cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
if offset > limit {
break
}
}
}
//go:build mlx
package mlx
/*
#include <stdio.h>
#include <string.h>
#include "mlx/c/array.h"
#include "mlx/c/ops.h"
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
void unpack_32_4(uint8_t* data, int8_t* dst) {
memset(dst, 0, 16);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
// Drived from ggml-quants.c
#define QK_K 256
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
uint16_t d; // super-block scale
} block_q6_K;
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
const int64_t nb = k / QK_K;
block_q6_K *x = (block_q6_K *)vx;
float16_t* y = (float16_t *)vy;
for (int i = 0; i < nb; i++) {
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
struct {
uint16_t d; // super-block scale for quantized scales
uint16_t dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
uint16_t dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
block_q4_K *x = (block_q4_K *)vx;
float16_t* y = (float16_t *)vy;
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
float16_t min = 0.0;
memcpy(&min, &x[i].dmin, sizeof(d));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float16_t d1 = d * sc; const float16_t m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float16_t d2 = d * sc; const float16_t m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/x448/float16"
)
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
shape := append([]C.int{}, final_shape...)
var weights_per_byte C.int
if dtype == 2 || dtype == 3 {
weights_per_byte = 2
} else if dtype == 8 {
weights_per_byte = 1
} else {
return r, fmt.Errorf("unsupported tensor type %d", dtype)
}
weights_per_block := C.int(32)
if shape[len(shape)-1]%weights_per_block != 0 {
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
}
weights_shape := append([]C.int{}, shape...)
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
for i := range weights_shape {
w_nbytes *= weights_shape[i]
}
w_data := make([]byte, w_nbytes)
cbytes := C.CBytes(w_data)
defer C.free(cbytes)
weights := C.mlx_array_new_data(
cbytes,
&weights_shape[0],
C.int(len(weights_shape)),
C.MLX_UINT32,
)
// For scales and bias
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
for i := range shape {
sb_nbytes *= shape[i]
}
s_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(s_data)
defer C.free(cbytes)
scales := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
b_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(b_data)
defer C.free(cbytes)
biases := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
var bits C.int
switch dtype {
case 2:
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 3:
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 8:
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 8
}
groupSize := C.mlx_optional_int{value: 32, has_value: true}
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
C.mlx_dequantize(
&r,
weights,
scales,
biases,
groupSize,
bitsOpt,
nil, // TODO mode
dtypeOpt,
stream,
)
C.mlx_array_free(weights)
C.mlx_array_free(scales)
C.mlx_array_free(biases)
return r, nil
}
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
size := 1
for _, d := range shape {
size *= int(d)
}
fdata := make([]float16.Float16, size)
switch dtype {
case 14:
C.dequant_row_q6_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
case 12:
C.dequant_row_q4_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
default:
return r, fmt.Errorf("unsupported K quant")
}
r = C.mlx_array_new_data(
unsafe.Pointer(&fdata[0]),
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
return r, nil
}
package ml
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"hash/maphash"
"io"
"log/slog"
"math"
"net/http"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/logutil"
)
// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
DeviceID
// Layers is a set of layer indicies to load
Layers []int
}
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
if len(g.Layers) == 0 {
return math.MaxInt
}
first := g.Layers[0]
for i := 1; i < len(g.Layers); i++ {
if g.Layers[i] < first {
first = g.Layers[i]
}
}
return first
}
func (g GPULayers) String() string {
if len(g.Layers) == 0 {
return ""
}
slices.Sort(g.Layers)
contiguous := true
base := g.Layers[0]
for i := range g.Layers {
if g.Layers[i] != base+i {
contiguous = false
break
}
}
if contiguous {
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
} else {
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
}
}
// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers
func (l GPULayersList) Len() int { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
li := l[i].FirstLayer()
lj := l[j].FirstLayer()
return li < lj
}
func (l GPULayersList) String() string {
if l.Sum() > 0 {
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
} else {
return fmt.Sprintf("%v", []GPULayers(l))
}
}
// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
var sum int
for _, g := range l {
sum += len(g.Layers)
}
return sum
}
var h maphash.Hash
// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
h.Reset()
for _, g := range l {
if len(g.Layers) > 0 {
h.WriteString(g.ID + g.Library)
for _, l := range g.Layers {
binary.Write(&h, binary.NativeEndian, int64(l))
}
}
}
return h.Sum64()
}
// ErrNoMem is returned when panicing due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
BackendMemory
}
func (e ErrNoMem) Error() string {
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
// Minimal unique device identification
type DeviceID struct {
// ID is an identifier for the device for matching with system
// management libraries. The ID is only unique for other devices
// using the same Library.
// This ID represents a "post filtered" view of the enumerated devices
// if the ID is numeric
ID string `json:"id"`
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
Library string `json:"backend,omitempty"`
}
// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string
// Weights is the per-layer memory needed for the model weights.
Weights []uint64
// Cache is the per-layer memory needed for the KV cache.
Cache []uint64
// Graph is the size of the compute graph. It is not per-layer.
Graph uint64
}
func sumMemory(mem []uint64) uint64 {
var sum uint64
for _, m := range mem {
sum += m
}
return sum
}
// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}
func memoryPresent(mem []uint64) bool {
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
InputWeights uint64
// CPU model components are located in system memory. This does not
// include unified memory allocated through the GPU.
CPU DeviceMemory
// GPU model components are located on one or more GPUs.
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
var total uint64
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := sumMemory(m.CPU.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := gpu.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.CPU.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
if total > 0 {
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
}
}
type DeviceInfo struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string `json:"name"`
// Description is the longer user-friendly identification of the device
Description string `json:"description"`
// FilterID is populated with the unfiltered device ID if a numeric ID is used
// so the device can be included.
FilterID string `json:"filter_id,omitempty"`
// Integrated is set true for integrated GPUs, false for Discrete GPUs
Integrated bool `json:"integration,omitempty"`
// PCIID is the bus, device and domain ID of the device for deduplication
// when discovered by multiple backends
PCIID string `json:"pci_id,omitempty"`
// TotalMemory is the total amount of memory the device can use for loading models
TotalMemory uint64 `json:"total_memory"`
// FreeMemory is the amount of memory currently available on the device for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// ComputeMajor is the major version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMajor int
// ComputeMinor is the minor version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMinor int
// Driver Information
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
// Where backends were loaded from
LibraryPath []string
}
type SystemInfo struct {
// ThreadCount is the optimal number of threads to use for inference
ThreadCount int `json:"threads,omitempty"`
// TotalMemory is the total amount of system memory
TotalMemory uint64 `json:"total_memory,omitempty"`
// FreeMemory is the amount of memory currently available on the system for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// FreeSwap is the amount of system swap space reported as available
FreeSwap uint64 `json:"free_swap,omitempty"`
}
func (d DeviceInfo) Compute() string {
// AMD gfx is encoded into the major minor in hex form
if strings.EqualFold(d.Library, "ROCm") {
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
}
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
func (d DeviceInfo) Driver() string {
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
if d.Library == "Metal" {
return 512 * format.MebiByte
}
return 457 * format.MebiByte
}
// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
if a[i].Integrated && !a[j].Integrated {
return true
} else if !a[i].Integrated && a[j].Integrated {
return false
}
return a[i].FreeMemory < a[j].FreeMemory
}
// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
scores := []bool{}
for _, info := range l {
found := false
requested := info.Integrated
for i, score := range scores {
if score == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
scores = append(scores, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir)
}
}
}
return gpuLibs
}
type DeviceComparison int
const (
UniqueDevice DeviceComparison = iota
SameBackendDevice // The device is the same, and the library/backend is the same
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
}
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
}
aLibSplit := strings.SplitN(aLib, "_", 2)
bLibSplit := strings.SplitN(bLib, "_", 2)
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
return false
}
if aLibSplit[0] != bLibSplit[0] {
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
return false
}
if aLibSplit[1] == bLibSplit[1] {
return false
}
cmp := []string{aLibSplit[1], bLibSplit[1]}
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
return cmp[0] == bLibSplit[1]
}
// For each GPU, check if it does NOT support flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false
}
}
return true
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
if len(l) == 0 {
return nil
}
env := map[string]string{}
for _, d := range l {
d.updateVisibleDevicesEnv(env, mustFilter)
}
return env
}
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}
// PreferredLibrary returns true if this library is preferred over the other input
// library
// Used to filter out Vulkan in favor of CUDA or ROCm
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
// TODO in the future if we find Vulkan is better than ROCm on some devices
// that implementation can live here.
if d.Library == "CUDA" || d.Library == "ROCm" {
return true
}
return false
}
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
var envVar string
switch d.Library {
case "ROCm":
// ROCm must be filtered as it can crash the runner on unsupported devices
envVar = "ROCR_VISIBLE_DEVICES"
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
}
case "CUDA":
if !mustFilter {
// By default we try to avoid filtering CUDA devices because ROCm also
// looks at the CUDA env var, and gets confused in mixed vendor environments.
return
}
envVar = "CUDA_VISIBLE_DEVICES"
default:
// Vulkan is not filtered via env var, but via scheduling decisions
return
}
v, existing := env[envVar]
if existing {
v = v + ","
}
if d.FilterID != "" {
v = v + d.FilterID
} else {
v = v + d.ID
}
env[envVar] = v
}
type BaseRunner interface {
// GetPort returns the localhost port number the runner is running on
GetPort() int
// HasExited indicates if the runner is no longer running. This can be used during
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
HasExited() bool
}
type RunnerDiscovery interface {
BaseRunner
// GetDeviceInfos will perform a query of the underlying device libraries
// for device identification and free VRAM information
// During bootstrap scenarios, this routine may take seconds to complete
GetDeviceInfos(ctx context.Context) []DeviceInfo
}
type FilteredRunnerDiscovery interface {
RunnerDiscovery
// GetActiveDeviceIDs returns the filtered set of devices actively in
// use by this runner for running models. If the runner is a bootstrap runner, no devices
// will be active yet so no device IDs are returned.
// This routine will not query the underlying device and will return immediately
GetActiveDeviceIDs() []DeviceID
}
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
var moreDevices []DeviceInfo
port := runner.GetPort()
tick := time.Tick(10 * time.Millisecond)
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Warn("failed to read response", "error", err)
continue
}
if resp.StatusCode != 200 {
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
return nil, fmt.Errorf("runner error: %s", string(body))
}
if err := json.Unmarshal(body, &moreDevices); err != nil {
slog.Warn("unmarshal encode response", "error", err)
continue
}
return moreDevices, nil
}
}
}
package nn
import (
"fmt"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
// panic("after cache get") //
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
// var mask ml.Tensor
if cache != nil {
key, value, _ = cache.Get(ctx)
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
// panic("after cache get") //
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if cache != nil {
// TODO what to do with vmla?
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
} else {
panic("else case not supported")
// TODO transpose shapes are wrong
// key = key.Transpose(ctx, 0, 2, 1, 3)
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
// kq := query.Matmul(ctx, key)
// kq = kq.Scale(ctx, scale)
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
// kq = kq.Softmax(ctx)
// kqv := kq.Matmul(ctx, value)
// if vmla != nil {
// kqv = kqv.Matmul(ctx, vmla)
// }
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
}
}
package nn
import "github.com/ollama/ollama/x/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
if m.Bias != nil {
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
}
return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment