"vscode:/vscode.git/clone" did not exist on "e6fe3ada570ec29154635f253907b67b08b52a66"
Commit 0cb78a2f authored by xuxzh1's avatar xuxzh1 🎱
Browse files

update

parent 217903ab
...@@ -61,11 +61,13 @@ const ( ...@@ -61,11 +61,13 @@ const (
MIIM_SUBMENU = 0x00000004 MIIM_SUBMENU = 0x00000004
MIM_APPLYTOSUBMENUS = 0x80000000 MIM_APPLYTOSUBMENUS = 0x80000000
NIF_ICON = 0x00000002 NIF_ICON = 0x00000002
NIF_TIP = 0x00000004
NIF_INFO = 0x00000010 NIF_INFO = 0x00000010
NIF_MESSAGE = 0x00000001 NIF_MESSAGE = 0x00000001
SW_HIDE = 0 SW_HIDE = 0
TPM_BOTTOMALIGN = 0x0020 TPM_BOTTOMALIGN = 0x0020
TPM_LEFTALIGN = 0x0000 TPM_LEFTALIGN = 0x0000
TPM_RIGHTBUTTON = 0x0002
WM_CLOSE = 0x0010 WM_CLOSE = 0x0010
WM_USER = 0x0400 WM_USER = 0x0400
WS_CAPTION = 0x00C00000 WS_CAPTION = 0x00C00000
......
This is here to make sure the build/ directory exists for the go:embed command
This is here to make sure the build/ directory exists for the go:embed command
package build
import "embed"
// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling
//go:embed darwin/amd64/*
var EmbedFS embed.FS
package build
import "embed"
// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling
//go:embed darwin/arm64/*
var EmbedFS embed.FS
package build
import "embed"
//go:embed linux/*
var EmbedFS embed.FS
//go:build !linux && !darwin
package build
import "embed"
// unused on windows
var EmbedFS embed.FS
This is here to make sure the build/ directory exists for the go:embed command
This is here to make sure the build/ directory exists for the go:embed command
This diff is collapsed.
package cmd
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
)
func TestShowInfo(t *testing.T) {
t.Run("bare details", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("bare model info", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(7_000_000_000),
"test.context_length": float64(0),
"test.embedding_length": float64(0),
},
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
context length 0
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("parameters", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Parameters: `
stop never
stop gonna
stop give
stop you
stop up
temperature 99`,
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
Parameters
stop never
stop gonna
stop give
stop you
stop up
temperature 99
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("project info", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
ProjectorInfo: map[string]any{
"general.architecture": "clip",
"general.parameter_count": float64(133_700_000),
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
Projector
architecture clip
parameters 133.70M
embedding length 0
dimensions 0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("system", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
System: `You are a pirate!
Ahoy, matey!
Weigh anchor!
`,
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
System
You are a pirate!
Ahoy, matey!
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("license", func(t *testing.T) {
var b bytes.Buffer
license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
if err != nil {
t.Fatal(err)
}
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
License: string(license),
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
License
MIT License
Copyright (c) Ollama
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {
stopped := false
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete {
var req api.DeleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name == "test-model" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
}
return
}
if r.URL.Path == "/api/generate" && r.Method == http.MethodPost {
var req api.GenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Model == "test-model" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{
Done: true,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
stopped = true
return
} else {
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{
Done: false,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err)
}
if !stopped {
t.Fatal("Model was not stopped before deletion")
}
err := DeleteHandler(cmd, []string{"test-model-not-found"})
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
modelfileName string
fileExists bool
expectedName string
expectedErr error
}{
{
name: "no modelfile specified, no modelfile exists",
modelfileName: "",
fileExists: false,
expectedName: "",
expectedErr: os.ErrNotExist,
},
{
name: "no modelfile specified, modelfile exists",
modelfileName: "",
fileExists: true,
expectedName: "Modelfile",
expectedErr: nil,
},
{
name: "modelfile specified, no modelfile exists",
modelfileName: "crazyfile",
fileExists: false,
expectedName: "crazyfile",
expectedErr: os.ErrNotExist,
},
{
name: "modelfile specified, modelfile exists",
modelfileName: "anotherfile",
fileExists: true,
expectedName: "anotherfile",
expectedErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{
Use: "fakecmd",
}
cmd.Flags().String("file", "", "path to modelfile")
var expectedFilename string
if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
} else {
fn = "Modelfile"
}
tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
} else {
if tt.modelfileName != "" {
expectedFilename = tt.modelfileName
err := cmd.Flags().Set("file", tt.modelfileName)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
}
}
actualFilename, actualErr := getModelfileName(cmd)
if actualFilename != expectedFilename {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
}
if tt.expectedErr != os.ErrNotExist {
if actualErr != tt.expectedErr {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
} else {
if !os.IsNotExist(actualErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
}
})
}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful push",
modelName: "test-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
var req api.PushRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
// Simulate progress updates
responses := []api.ProgressResponse{
{Status: "preparing manifest"},
{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)
return
}
http.Error(w, "not found", http.StatusNotFound)
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err := PushHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
) )
...@@ -31,26 +30,6 @@ const ( ...@@ -31,26 +30,6 @@ const (
MultilineSystem MultilineSystem
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error {
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
chatReq := &api.ChatRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
}
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
continue continue
...@@ -340,8 +319,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -340,8 +319,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, newMessage) opts.Messages = append(opts.Messages, newMessage)
} }
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
sb.Reset() sb.Reset()
continue continue
default: default:
...@@ -371,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -371,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] { switch args[1] {
case "info": case "info":
showInfo(resp) _ = showInfo(resp, os.Stderr)
case "license": case "license":
if resp.License == "" { if resp.License == "" {
fmt.Println("No license was specified for this model.") fmt.Println("No license was specified for this model.")
...@@ -463,13 +440,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -463,13 +440,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
// clear all previous images for better responses
if len(images) > 0 {
for i := range opts.Messages {
opts.Messages[i].Images = nil
}
}
newMessage.Content = msg newMessage.Content = msg
newMessage.Images = images newMessage.Images = images
} }
...@@ -522,35 +492,29 @@ func buildModelfile(opts runOptions) string { ...@@ -522,35 +492,29 @@ func buildModelfile(opts runOptions) string {
} }
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements return strings.NewReplacer(
replacements := map[string]string{ "\\ ", " ", // Escaped space
"\\ ": " ", // Escaped space "\\(", "(", // Escaped left parenthesis
"\\(": "(", // Escaped left parenthesis "\\)", ")", // Escaped right parenthesis
"\\)": ")", // Escaped right parenthesis "\\[", "[", // Escaped left square bracket
"\\[": "[", // Escaped left square bracket "\\]", "]", // Escaped right square bracket
"\\]": "]", // Escaped right square bracket "\\{", "{", // Escaped left curly brace
"\\{": "{", // Escaped left curly brace "\\}", "}", // Escaped right curly brace
"\\}": "}", // Escaped right curly brace "\\$", "$", // Escaped dollar sign
"\\$": "$", // Escaped dollar sign "\\&", "&", // Escaped ampersand
"\\&": "&", // Escaped ampersand "\\;", ";", // Escaped semicolon
"\\;": ";", // Escaped semicolon "\\'", "'", // Escaped single quote
"\\'": "'", // Escaped single quote "\\\\", "\\", // Escaped backslash
"\\\\": "\\", // Escaped backslash "\\*", "*", // Escaped asterisk
"\\*": "*", // Escaped asterisk "\\?", "?", // Escaped question mark
"\\?": "?", // Escaped question mark ).Replace(fp)
}
for escaped, actual := range replacements {
fp = strings.ReplaceAll(fp, escaped, actual)
}
return fp
} }
func extractFileNames(input string) []string { func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
...@@ -563,10 +527,9 @@ func extractFileData(input string) (string, []api.ImageData, error) { ...@@ -563,10 +527,9 @@ func extractFileData(input string) (string, []api.ImageData, error) {
for _, fp := range filePaths { for _, fp := range filePaths {
nfp := normalizeFilePath(fp) nfp := normalizeFilePath(fp)
data, err := getImageData(nfp) data, err := getImageData(nfp)
if err != nil { if errors.Is(err, os.ErrNotExist) {
if os.IsNotExist(err) {
continue continue
} } else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
return "", imgs, err return "", imgs, err
} }
...@@ -574,7 +537,7 @@ func extractFileData(input string) (string, []api.ImageData, error) { ...@@ -574,7 +537,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs = append(imgs, data)
} }
return input, imgs, nil return strings.TrimSpace(input), imgs, nil
} }
func getImageData(filePath string) ([]byte, error) { func getImageData(filePath string) ([]byte, error) {
......
...@@ -12,44 +12,45 @@ import ( ...@@ -12,44 +12,45 @@ import (
func TestExtractFilenames(t *testing.T) { func TestExtractFilenames(t *testing.T) {
// Unix style paths // Unix style paths
input := ` some preamble input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg` /unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input) res := extractFileNames(input)
assert.Len(t, res, 5) assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg") assert.Contains(t, res[4], "five.JPG")
assert.NotContains(t, res[4], '"') assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbtween") assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
// Windows style paths // Windows style paths
input = ` some preamble input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2 c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6 ./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8 d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
` `
res = extractFileNames(input) res = extractFileNames(input)
assert.Len(t, res, 10) assert.Len(t, res, 10)
assert.NotContains(t, res, "inbtween") assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:") assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:") assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg") assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.png") assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.svg") assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "d:") assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png") assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:") assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png") assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:") assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
} }
......
...@@ -7,16 +7,27 @@ import ( ...@@ -7,16 +7,27 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"strings"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type Parameters struct { type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
} }
func (Parameters) KV(t *Tokenizer) llm.KV { type AdapterParameters struct {
Alpha uint32 `json:"lora_alpha"`
LoraLayers uint32 `json:"lora_layers"`
LoraParameters struct {
Rank uint32 `json:"rank"`
Alpha float32 `json:"alpha"`
Scale float32 `json:"scale"`
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) llm.KV {
kv := llm.KV{ kv := llm.KV{
"general.file_type": uint32(1), "general.file_type": uint32(1),
"general.quantization_version": uint32(2), "general.quantization_version": uint32(2),
...@@ -27,6 +38,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV { ...@@ -27,6 +38,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
"tokenizer.ggml.token_type": t.Vocabulary.Types, "tokenizer.ggml.token_type": t.Vocabulary.Types,
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
if t.Template != "" { if t.Template != "" {
kv["tokenizer.chat_template"] = t.Template kv["tokenizer.chat_template"] = t.Template
} }
...@@ -39,40 +54,119 @@ func (Parameters) KV(t *Tokenizer) llm.KV { ...@@ -39,40 +54,119 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
return kv return kv
} }
func (Parameters) specialTokenTypes() []string { func (p AdapterParameters) KV() llm.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
} else {
alpha = p.LoraParameters.Alpha
}
kv := llm.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
"general.type": "adapter",
"general.version": "v0.2",
}
return kv
}
func (ModelParameters) specialTokenTypes() []string {
return []string{ return []string{
"bos", "eos", "unk", "sep", "pad", "cls", "mask", "bos", "eos", "unk", "sep", "pad", "cls", "mask",
} }
} }
func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error { func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts) return llm.WriteGGUF(ws, kv, ts)
} }
type Converter interface { func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(*Tokenizer) llm.KV KV(*Tokenizer) llm.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor Tensors([]Tensor) []llm.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
// tensorName returns the LLM tensor name for a specific input name
tensorName(string) string
// specialTokenTypes returns any special token types the model uses // specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
type moreParser interface {
parseMore(fs.FS) error
}
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(llm.KV) llm.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
} }
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
}
var p AdapterParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
var conv AdapterConverter
switch arch {
case "llama":
conv = &llamaAdapter{}
case "gemma2":
conv = &gemma2Adapter{}
default:
return errors.New("unsupported architecture")
}
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
return err
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
}
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path. // and files it finds in the input path.
// Supported input model formats include safetensors. // Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func Convert(fsys fs.FS, ws io.WriteSeeker) error { func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json") bts, err := fs.ReadFile(fsys, "config.json")
if err != nil { if err != nil {
return err return err
} }
var p Parameters var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil { if err := json.Unmarshal(bts, &p); err != nil {
return err return err
} }
...@@ -81,14 +175,20 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error { ...@@ -81,14 +175,20 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
return errors.New("unknown architecture") return errors.New("unknown architecture")
} }
var conv Converter var conv ModelConverter
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM": case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llama{} conv = &llamaModel{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtral{} conv = &mixtralModel{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
conv = &gemma{} conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "BertModel":
conv = &bertModel{}
default: default:
return errors.New("unsupported architecture") return errors.New("unsupported architecture")
} }
...@@ -97,23 +197,33 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error { ...@@ -97,23 +197,33 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
return err return err
} }
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes()) t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil { if err != nil {
return err return err
} }
if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) { vocabSize := int(p.VocabSize)
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens)) switch {
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
} }
} else { case vocabSize < len(t.Vocabulary.Tokens):
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
} }
ts, err := parseTensors(fsys) ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil { if err != nil {
return err return err
} }
......
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/llm"
)
type bertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayer uint32 `json:"n_layer"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NCtx uint32 `json:"n_ctx"`
HiddenSize uint32 `json:"hidden_size"`
NEmbd uint32 `json:"n_embd"`
IntermediateSize uint32 `json:"intermediate_size"`
NInner uint32 `json:"n_inner"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NHead uint32 `json:"n_head"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
PoolingType uint32
}
var (
_ ModelConverter = (*bertModel)(nil)
_ moreParser = (*bertModel)(nil)
)
func (p *bertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" {
pooling = m.Path
break
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *bertModel) KV(t *Tokenizer) llm.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
kv["bert.context_length"] = contextLength
}
if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
kv["bert.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
}
if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
kv["bert.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
}
if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
kv["bert.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
kv["bert.attention.layer_norm_epsilon"] = layerNormEpsilon
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
if strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]") {
// noop
} else if strings.HasPrefix(e, "##") {
t.Tokens[i] = e[2:]
} else {
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (bertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"embeddings.position_embeddings", "position_embd",
"attention.self.query", "attn_q",
"attention.self.key", "attn_k",
"attention.self.value", "attn_v",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type gemma struct { type gemmaModel struct {
Parameters ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"` HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"` HiddenLayers uint32 `json:"num_hidden_layers"`
...@@ -21,12 +21,11 @@ type gemma struct { ...@@ -21,12 +21,11 @@ type gemma struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
} }
var _ Converter = (*gemma)(nil) var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemma) KV(t *Tokenizer) llm.KV { func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma" kv["general.architecture"] = "gemma"
kv["general.name"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings kv["gemma.context_length"] = p.MaxPositionEmbeddings
kv["gemma.embedding_length"] = p.HiddenSize kv["gemma.embedding_length"] = p.HiddenSize
kv["gemma.block_count"] = p.HiddenLayers kv["gemma.block_count"] = p.HiddenLayers
...@@ -43,16 +42,15 @@ func (p *gemma) KV(t *Tokenizer) llm.KV { ...@@ -43,16 +42,15 @@ func (p *gemma) KV(t *Tokenizer) llm.KV {
return kv return kv
} }
func (p *gemma) Tensors(ts []Tensor) []llm.Tensor { func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor var out []llm.Tensor
for _, t := range ts { for _, t := range ts {
name := p.tensorName(t.Name()) if strings.HasSuffix(t.Name(), "_norm.weight") {
if strings.HasSuffix(name, "_norm.weight") {
t.SetRepacker(p.addOne) t.SetRepacker(p.addOne)
} }
out = append(out, llm.Tensor{ out = append(out, llm.Tensor{
Name: name, Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
...@@ -62,8 +60,8 @@ func (p *gemma) Tensors(ts []Tensor) []llm.Tensor { ...@@ -62,8 +60,8 @@ func (p *gemma) Tensors(ts []Tensor) []llm.Tensor {
return out return out
} }
func (p *gemma) tensorName(n string) string { func (p *gemmaModel) Replacements() []string {
return strings.NewReplacer( return []string{
"model.embed_tokens", "token_embd", "model.embed_tokens", "token_embd",
"model.norm", "output_norm", "model.norm", "output_norm",
"model.layers", "blk", "model.layers", "blk",
...@@ -76,11 +74,10 @@ func (p *gemma) tensorName(n string) string { ...@@ -76,11 +74,10 @@ func (p *gemma) tensorName(n string) string {
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm", "post_attention_layernorm", "ffn_norm",
"block_sparse_moe.gate", "ffn_inp", }
).Replace(n)
} }
func (*gemma) addOne(_ string, data []float32, shape []uint64) ([]float32, error) { func (*gemmaModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data)) n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0])) ones := tensor.Ones(tensor.Float32, int(shape[0]))
......
package convert
import (
"github.com/ollama/ollama/llm"
)
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
AttentionLogitSoftcap float32 `json:"attn_logit_softcapping"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
kv["gemma2.embedding_length"] = p.HiddenSize
kv["gemma2.block_count"] = p.HiddenLayers
kv["gemma2.feed_forward_length"] = p.IntermediateSize
kv["gemma2.attention.head_count"] = p.NumAttentionHeads
kv["gemma2.attention.head_count_kv"] = p.NumKeyValueHeads
kv["gemma2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma2.attention.key_length"] = p.HeadDim
kv["gemma2.attention.value_length"] = p.HeadDim
kv["gemma2.attention.sliding_window"] = p.SlidingWindow
kv["gemma2.attn_logit_softcapping"] = p.AttentionLogitSoftcap
kv["gemma2.final_logit_softcapping"] = p.FinalLogitSoftcap
kv["tokenizer.ggml.eot_token_id"] = uint32(107)
kv["tokenizer.ggml.middle_token_id"] = uint32(68)
kv["tokenizer.ggml.prefix_token_id"] = uint32(67)
kv["tokenizer.ggml.suffix_token_id"] = uint32(69)
return kv
}
func (p *gemma2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
}
}
package convert
import (
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type gemma2Adapter struct {
AdapterParameters
}
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
shape[0], shape[1] = shape[1], shape[0]
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *gemma2Adapter) Replacements() []string {
return []string{
"base_model.model.", "",
"model.layers", "blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"lora_A.weight", "weight.lora_a",
"lora_B.weight", "weight.lora_b",
"lora_a", "weight.lora_a",
"lora_b", "weight.lora_b",
}
}
func (p *gemma2Adapter) repack(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.T(1, 0); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
...@@ -3,6 +3,7 @@ package convert ...@@ -3,6 +3,7 @@ package convert
import ( import (
"cmp" "cmp"
"fmt" "fmt"
"math"
"strings" "strings"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
...@@ -11,8 +12,8 @@ import ( ...@@ -11,8 +12,8 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type llama struct { type llamaModel struct {
Parameters ModelParameters
NLayers uint32 `json:"n_layers"` NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"` NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayer uint32 `json:"n_layer"` NLayer uint32 `json:"n_layer"`
...@@ -28,7 +29,13 @@ type llama struct { ...@@ -28,7 +29,13 @@ type llama struct {
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
RopeScaling struct { RopeScaling struct {
Type string `json:"type"` Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"` Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"` } `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"` RMSNormEPS float32 `json:"rms_norm_eps"`
LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEPS float32 `json:"layer_norm_eps"`
...@@ -37,12 +44,11 @@ type llama struct { ...@@ -37,12 +44,11 @@ type llama struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
} }
var _ Converter = (*llama)(nil) var _ ModelConverter = (*llamaModel)(nil)
func (p *llama) KV(t *Tokenizer) llm.KV { func (p *llamaModel) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama" kv["general.architecture"] = "llama"
kv["general.name"] = "llama"
kv["llama.vocab_size"] = p.VocabSize kv["llama.vocab_size"] = p.VocabSize
kv["llama.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) kv["llama.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
...@@ -71,6 +77,27 @@ func (p *llama) KV(t *Tokenizer) llm.KV { ...@@ -71,6 +77,27 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
if p.RopeScaling.Type == "linear" { if p.RopeScaling.Type == "linear" {
kv["llama.rope.scaling.type"] = p.RopeScaling.Type kv["llama.rope.scaling.type"] = p.RopeScaling.Type
kv["llama.rope.scaling.factor"] = p.RopeScaling.Factor kv["llama.rope.scaling.factor"] = p.RopeScaling.Factor
} else if p.RopeScaling.RopeType == "llama3" {
dim := p.HiddenSize / p.NumAttentionHeads
for i := uint32(0); i < dim; i += 2 {
factor := cmp.Or(p.RopeScaling.Factor, 8.0)
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
lambda := 2 * math.Pi * math.Pow(float64(p.RopeTheta), float64(i)/float64(dim))
if lambda < float64(lambdaHigh) {
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
} else if lambda > float64(lambdaLow) {
p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
} else {
smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
}
}
} }
if p.NumKeyValueHeads > 0 { if p.NumKeyValueHeads > 0 {
...@@ -90,24 +117,29 @@ func (p *llama) KV(t *Tokenizer) llm.KV { ...@@ -90,24 +117,29 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
kv["llama.attention.value_length"] = p.HeadDim kv["llama.attention.value_length"] = p.HeadDim
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
return kv return kv
} }
func (p *llama) Tensors(ts []Tensor) []llm.Tensor { func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor var out []llm.Tensor
if p.RopeScaling.factors != nil {
out = append(out, llm.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
WriterTo: p.RopeScaling.factors,
})
}
for _, t := range ts { for _, t := range ts {
name := p.tensorName(t.Name()) if strings.HasSuffix(t.Name(), "attn_q.weight") ||
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
strings.HasSuffix(name, "attn_k.weight") {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, llm.Tensor{ out = append(out, llm.Tensor{
Name: name, Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
...@@ -117,8 +149,8 @@ func (p *llama) Tensors(ts []Tensor) []llm.Tensor { ...@@ -117,8 +149,8 @@ func (p *llama) Tensors(ts []Tensor) []llm.Tensor {
return out return out
} }
func (p *llama) tensorName(n string) string { func (p *llamaModel) Replacements() []string {
return strings.NewReplacer( return []string{
"lm_head", "output", "lm_head", "output",
"model.embed_tokens", "token_embd", "model.embed_tokens", "token_embd",
"model.norm", "output_norm", "model.norm", "output_norm",
...@@ -132,21 +164,19 @@ func (p *llama) tensorName(n string) string { ...@@ -132,21 +164,19 @@ func (p *llama) tensorName(n string) string {
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm", "post_attention_layernorm", "ffn_norm",
// mixtral }
"block_sparse_moe.gate", "ffn_gate_inp",
).Replace(n)
} }
func (p *llama) repack(name string, data []float32, shape []uint64) ([]float32, error) { func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int var dims []int
for _, dim := range shape { for _, dim := range shape {
dims = append(dims, int(dim)) dims = append(dims, int(dim))
} }
var heads uint32 var heads uint32
if strings.HasSuffix(name, "q_proj.weight") { if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "k_proj.weight") { } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else { } else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name) return nil, fmt.Errorf("unknown tensor for repack: %s", name)
......
package convert
import (
"cmp"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type llamaAdapter struct {
AdapterParameters
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
}
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}
func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
shape[0], shape[1] = shape[1], shape[0]
t.SetRepacker(p.repackAndTranspose)
} else {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
return out
}
func (p *llamaAdapter) Replacements() []string {
return []string{
"base_model.model.", "",
"model.layers", "blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"lora_A.weight", "weight.lora_a",
"lora_B.weight", "weight.lora_b",
"lora_a", "weight.lora_a",
"lora_b", "weight.lora_b",
}
}
func (p *llamaAdapter) repack(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight.lora_a") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight.lora_a") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return data, nil
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
func (p *llamaAdapter) repackAndTranspose(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var heads uint32
if strings.HasSuffix(name, "attn_q.weight.lora_a") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight.lora_a") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
}
if heads > 0 {
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
}
if err := n.T(1, 0); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment