Commit 0cb78a2f authored by xuxzh1's avatar xuxzh1 🎱
Browse files

update

parent 217903ab
...@@ -61,11 +61,13 @@ const ( ...@@ -61,11 +61,13 @@ const (
MIIM_SUBMENU = 0x00000004 MIIM_SUBMENU = 0x00000004
MIM_APPLYTOSUBMENUS = 0x80000000 MIM_APPLYTOSUBMENUS = 0x80000000
NIF_ICON = 0x00000002 NIF_ICON = 0x00000002
NIF_TIP = 0x00000004
NIF_INFO = 0x00000010 NIF_INFO = 0x00000010
NIF_MESSAGE = 0x00000001 NIF_MESSAGE = 0x00000001
SW_HIDE = 0 SW_HIDE = 0
TPM_BOTTOMALIGN = 0x0020 TPM_BOTTOMALIGN = 0x0020
TPM_LEFTALIGN = 0x0000 TPM_LEFTALIGN = 0x0000
TPM_RIGHTBUTTON = 0x0002
WM_CLOSE = 0x0010 WM_CLOSE = 0x0010
WM_USER = 0x0400 WM_USER = 0x0400
WS_CAPTION = 0x00C00000 WS_CAPTION = 0x00C00000
......
This is here to make sure the build/ directory exists for the go:embed command
This is here to make sure the build/ directory exists for the go:embed command
package build
import "embed"
// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling
//go:embed darwin/amd64/*
var EmbedFS embed.FS
package build
import "embed"
// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling
//go:embed darwin/arm64/*
var EmbedFS embed.FS
package build
import "embed"
//go:embed linux/*
var EmbedFS embed.FS
//go:build !linux && !darwin
package build
import "embed"
// unused on windows
var EmbedFS embed.FS
This is here to make sure the build/ directory exists for the go:embed command
This is here to make sure the build/ directory exists for the go:embed command
...@@ -2,6 +2,7 @@ package cmd ...@@ -2,6 +2,7 @@ package cmd
import ( import (
"archive/zip" "archive/zip"
"bufio"
"bytes" "bytes"
"context" "context"
"crypto/ed25519" "crypto/ed25519"
...@@ -18,10 +19,10 @@ import ( ...@@ -18,10 +19,10 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"slices" "strconv"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
...@@ -33,39 +34,67 @@ import ( ...@@ -33,39 +34,67 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func CreateHandler(cmd *cobra.Command, args []string) error { var (
filename, _ := cmd.Flags().GetString("file") errModelNotFound = errors.New("no Modelfile or safetensors files found")
filename, err := filepath.Abs(filename) errModelfileNotFound = errors.New("specified Modelfile wasn't found")
)
func getModelfileName(cmd *cobra.Command) (string, error) {
fn, _ := cmd.Flags().GetString("file")
filename := fn
if filename == "" {
filename = "Modelfile"
}
absName, err := filepath.Abs(filename)
if err != nil { if err != nil {
return err return "", err
} }
client, err := api.ClientFromEnvironment() _, err = os.Stat(absName)
if err != nil { if err != nil {
return err return fn, err
} }
return absName, nil
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
f, err := os.Open(filename) var reader io.Reader
if err != nil {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
}
} else if err != nil {
return err return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
reader = f
defer f.Close()
} }
defer f.Close()
modelfile, err := parser.ParseFile(f) modelfile, err := parser.ParseFile(reader)
if err != nil { if err != nil {
return err return err
} }
...@@ -78,6 +107,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error { ...@@ -78,6 +107,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data" status := "transferring model data"
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
defer p.Stop()
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
...@@ -112,7 +147,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { ...@@ -112,7 +147,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, client, path, spinner)
if err != nil { if err != nil {
return err return err
} }
...@@ -202,6 +237,12 @@ func tempZipFiles(path string) (string, error) { ...@@ -202,6 +237,12 @@ func tempZipFiles(path string) (string, error) {
// safetensors files might be unresolved git lfs references; skip if they are // safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...) files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
...@@ -211,7 +252,7 @@ func tempZipFiles(path string) (string, error) { ...@@ -211,7 +252,7 @@ func tempZipFiles(path string) (string, error) {
// covers consolidated.x.pth, consolidated.pth // covers consolidated.x.pth, consolidated.pth
files = append(files, pt...) files = append(files, pt...)
} else { } else {
return "", errors.New("no safetensors or torch files found") return "", errModelNotFound
} }
// add configuration files, json files are detected as text/plain // add configuration files, json files are detected as text/plain
...@@ -221,6 +262,14 @@ func tempZipFiles(path string) (string, error) { ...@@ -221,6 +262,14 @@ func tempZipFiles(path string) (string, error) {
} }
files = append(files, js...) files = append(files, js...)
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is // tokenizer.model might be a unresolved git lfs reference; error if it is
...@@ -250,6 +299,11 @@ func tempZipFiles(path string) (string, error) { ...@@ -250,6 +299,11 @@ func tempZipFiles(path string) (string, error) {
return "", err return "", err
} }
zfi.Name, err = filepath.Rel(path, file)
if err != nil {
return "", err
}
zf, err := zipfile.CreateHeader(zfi) zf, err := zipfile.CreateHeader(zfi)
if err != nil { if err != nil {
return "", err return "", err
...@@ -263,13 +317,20 @@ func tempZipFiles(path string) (string, error) { ...@@ -263,13 +317,20 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer bin.Close() defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New() hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil { if _, err := io.Copy(hash, bin); err != nil {
return "", err return "", err
...@@ -279,13 +340,76 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er ...@@ -279,13 +340,76 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err return "", err
} }
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
type progressWriter struct {
n atomic.Int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n.Add(int64(len(p)))
return len(p), nil
}
func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
}
func StopHandler(cmd *cobra.Command, args []string) error {
opts := &runOptions{
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}
if err := loadOrUnloadModel(cmd, opts); err != nil {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
}
return nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
...@@ -329,6 +453,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { ...@@ -329,6 +453,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if len(prompts) > 0 { if len(prompts) > 0 {
interactive = false interactive = false
} }
// Be quiet if we're redirecting to a pipe or file
if !term.IsTerminal(int(os.Stdout.Fd())) {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap") nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil { if err != nil {
...@@ -360,11 +488,11 @@ func RunHandler(cmd *cobra.Command, args []string) error { ...@@ -360,11 +488,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
opts.MultiModal = slices.Contains(info.Details.Families, "clip") opts.MultiModal = len(info.ProjectorInfo) != 0
opts.ParentModel = info.Details.ParentModel opts.ParentModel = info.Details.ParentModel
if interactive { if interactive {
if err := loadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
...@@ -385,47 +513,6 @@ func RunHandler(cmd *cobra.Command, args []string) error { ...@@ -385,47 +513,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts) return generate(cmd, opts)
} }
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
var msg strings.Builder
msg.WriteString(unknownKeyErr.Error())
msg.WriteString("\n\nYour ollama key is:\n")
msg.WriteString(localPubKey)
msg.WriteString("\nAdd your key at:\n")
msg.WriteString("https://ollama.com/settings/keys")
return errors.New(msg.String())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
...@@ -472,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { ...@@ -472,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
} }
request := api.PushRequest{Name: args[0], Insecure: insecure} request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
if err := client.Push(cmd.Context(), &request, fn); err != nil { if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
...@@ -479,18 +568,19 @@ func PushHandler(cmd *cobra.Command, args []string) error { ...@@ -479,18 +568,19 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "access denied") { if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
} }
host := model.ParseName(args[0]).Host
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// re-throw an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err return err
} }
p.Stop()
spinner.Stop() spinner.Stop()
destination := n.String()
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
}
fmt.Printf("\nYou can find your model at:\n\n")
fmt.Printf("\t%s\n", destination)
return nil return nil
} }
...@@ -520,7 +610,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { ...@@ -520,7 +610,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
table.SetHeaderLine(false) table.SetHeaderLine(false)
table.SetBorder(false) table.SetBorder(false)
table.SetNoWhiteSpace(true) table.SetNoWhiteSpace(true)
table.SetTablePadding("\t") table.SetTablePadding(" ")
table.AppendBulk(data) table.AppendBulk(data)
table.Render() table.Render()
...@@ -555,7 +645,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error { ...@@ -555,7 +645,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100) cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent)) procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
} }
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
var until string
delta := time.Since(m.ExpiresAt)
if delta > 0 {
until = "Stopping..."
} else {
until = format.HumanTime(m.ExpiresAt, "Never")
}
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
} }
} }
...@@ -566,7 +664,7 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error { ...@@ -566,7 +664,7 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
table.SetHeaderLine(false) table.SetHeaderLine(false)
table.SetBorder(false) table.SetBorder(false)
table.SetNoWhiteSpace(true) table.SetNoWhiteSpace(true)
table.SetTablePadding("\t") table.SetTablePadding(" ")
table.AppendBulk(data) table.AppendBulk(data)
table.Render() table.Render()
...@@ -579,6 +677,17 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { ...@@ -579,6 +677,17 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
// Unload the model if it's running before deletion
opts := &runOptions{
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}
if err := loadOrUnloadModel(cmd, opts); err != nil {
if !strings.Contains(err.Error(), "not found") {
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
}
}
for _, name := range args { for _, name := range args {
req := api.DeleteRequest{Name: name} req := api.DeleteRequest{Name: name}
if err := client.Delete(cmd.Context(), &req); err != nil { if err := client.Delete(cmd.Context(), &req); err != nil {
...@@ -654,130 +763,97 @@ func ShowHandler(cmd *cobra.Command, args []string) error { ...@@ -654,130 +763,97 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
case "parameters": case "parameters":
fmt.Println(resp.Parameters) fmt.Println(resp.Parameters)
case "system": case "system":
fmt.Println(resp.System) fmt.Print(resp.System)
case "template": case "template":
fmt.Println(resp.Template) fmt.Print(resp.Template)
} }
return nil return nil
} }
showInfo(resp) return showInfo(resp, os.Stdout)
return nil
} }
func showInfo(resp *api.ShowResponse) { func showInfo(resp *api.ShowResponse, w io.Writer) error {
arch := resp.ModelInfo["general.architecture"].(string) tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header)
modelData := [][]string{ table := tablewriter.NewWriter(w)
{"arch", arch}, table.SetAlignment(tablewriter.ALIGN_LEFT)
{"parameters", resp.Details.ParameterSize}, table.SetBorder(false)
{"quantization", resp.Details.QuantizationLevel}, table.SetNoWhiteSpace(true)
{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))}, table.SetTablePadding(" ")
{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
} switch header {
case "Template", "System", "License":
mainTableData := [][]string{ table.SetColWidth(100)
{"Model"},
{renderSubTable(modelData, false)},
}
if resp.ProjectorInfo != nil {
projectorData := [][]string{
{"arch", "clip"},
{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
} }
if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok { table.AppendBulk(rows())
projectorData = append(projectorData, []string{"projector type", projectorType.(string)}) table.Render()
fmt.Fprintln(w)
}
tableRender("Model", func() (rows [][]string) {
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
} }
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
return
})
projectorData = append(projectorData, if resp.ProjectorInfo != nil {
[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))}, tableRender("Projector", func() (rows [][]string) {
[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))}, arch := resp.ProjectorInfo["general.architecture"].(string)
) rows = append(rows, []string{"", "architecture", arch})
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))})
mainTableData = append(mainTableData, rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ProjectorInfo[fmt.Sprintf("%s.vision.embedding_length", arch)].(float64), 'f', -1, 64)})
[]string{"Projector"}, rows = append(rows, []string{"", "dimensions", strconv.FormatFloat(resp.ProjectorInfo[fmt.Sprintf("%s.vision.projection_dim", arch)].(float64), 'f', -1, 64)})
[]string{renderSubTable(projectorData, false)}, return
) })
} }
if resp.Parameters != "" { if resp.Parameters != "" {
mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)}) tableRender("Parameters", func() (rows [][]string) {
} scanner := bufio.NewScanner(strings.NewReader(resp.Parameters))
for scanner.Scan() {
if resp.System != "" { if text := scanner.Text(); text != "" {
mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)}) rows = append(rows, append([]string{""}, strings.Fields(text)...))
} }
}
if resp.License != "" { return
mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)}) })
}
table := tablewriter.NewWriter(os.Stdout)
table.SetAutoWrapText(false)
table.SetBorder(false)
table.SetAlignment(tablewriter.ALIGN_LEFT)
for _, v := range mainTableData {
table.Append(v)
}
table.Render()
}
func renderSubTable(data [][]string, file bool) string {
var buf bytes.Buffer
table := tablewriter.NewWriter(&buf)
table.SetAutoWrapText(!file)
table.SetBorder(false)
table.SetNoWhiteSpace(true)
table.SetTablePadding("\t")
table.SetAlignment(tablewriter.ALIGN_LEFT)
for _, v := range data {
table.Append(v)
}
table.Render()
renderedTable := buf.String()
lines := strings.Split(renderedTable, "\n")
for i, line := range lines {
lines[i] = "\t" + line
} }
return strings.Join(lines, "\n") head := func(s string, n int) (rows [][]string) {
} scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) {
func twoLines(s string) [][]string { if text := scanner.Text(); text != "" {
lines := strings.Split(s, "\n") rows = append(rows, []string{"", strings.TrimSpace(text)})
res := [][]string{}
count := 0
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" {
count++
res = append(res, []string{line})
if count == 2 {
return res
} }
} }
return
} }
return res
}
func formatParams(s string) string { if resp.System != "" {
lines := strings.Split(s, "\n") tableRender("System", func() [][]string {
table := [][]string{} return head(resp.System, 2)
})
}
for _, line := range lines { if resp.License != "" {
table = append(table, strings.Fields(line)) tableRender("License", func() [][]string {
return head(resp.License, 2)
})
} }
return renderSubTable(table, false)
return nil
} }
func CopyHandler(cmd *cobra.Command, args []string) error { func CopyHandler(cmd *cobra.Command, args []string) error {
...@@ -1086,7 +1162,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { ...@@ -1086,7 +1162,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
return nil return nil
} }
func RunServer(cmd *cobra.Command, _ []string) error { func RunServer(_ *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil { if err := initializeKeypair(); err != nil {
return err return err
} }
...@@ -1205,7 +1281,7 @@ func NewCLI() *cobra.Command { ...@@ -1205,7 +1281,7 @@ func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false cobra.EnableCommandSorting = false
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" && term.IsTerminal(int(os.Stdout.Fd())) {
console.ConsoleFromFile(os.Stdin) //nolint:errcheck console.ConsoleFromFile(os.Stdin) //nolint:errcheck
} }
...@@ -1237,7 +1313,7 @@ func NewCLI() *cobra.Command { ...@@ -1237,7 +1313,7 @@ func NewCLI() *cobra.Command {
RunE: CreateHandler, RunE: CreateHandler,
} }
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile") createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
showCmd := &cobra.Command{ showCmd := &cobra.Command{
...@@ -1267,6 +1343,15 @@ func NewCLI() *cobra.Command { ...@@ -1267,6 +1343,15 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)") runCmd.Flags().String("format", "", "Response format (e.g. json)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: StopHandler,
}
serveCmd := &cobra.Command{ serveCmd := &cobra.Command{
Use: "serve", Use: "serve",
Aliases: []string{"start"}, Aliases: []string{"start"},
...@@ -1334,6 +1419,7 @@ func NewCLI() *cobra.Command { ...@@ -1334,6 +1419,7 @@ func NewCLI() *cobra.Command {
createCmd, createCmd,
showCmd, showCmd,
runCmd, runCmd,
stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
...@@ -1360,6 +1446,8 @@ func NewCLI() *cobra.Command { ...@@ -1360,6 +1446,8 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"], envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],
}) })
default: default:
appendEnvDocs(cmd, envs) appendEnvDocs(cmd, envs)
...@@ -1371,6 +1459,7 @@ func NewCLI() *cobra.Command { ...@@ -1371,6 +1459,7 @@ func NewCLI() *cobra.Command {
createCmd, createCmd,
showCmd, showCmd,
runCmd, runCmd,
stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
......
package cmd
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
)
func TestShowInfo(t *testing.T) {
t.Run("bare details", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("bare model info", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(7_000_000_000),
"test.context_length": float64(0),
"test.embedding_length": float64(0),
},
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
context length 0
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("parameters", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Parameters: `
stop never
stop gonna
stop give
stop you
stop up
temperature 99`,
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
Parameters
stop never
stop gonna
stop give
stop you
stop up
temperature 99
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("project info", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
ProjectorInfo: map[string]any{
"general.architecture": "clip",
"general.parameter_count": float64(133_700_000),
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
Projector
architecture clip
parameters 133.70M
embedding length 0
dimensions 0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("system", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
System: `You are a pirate!
Ahoy, matey!
Weigh anchor!
`,
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
System
You are a pirate!
Ahoy, matey!
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("license", func(t *testing.T) {
var b bytes.Buffer
license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
if err != nil {
t.Fatal(err)
}
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
License: string(license),
}, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
License
MIT License
Copyright (c) Ollama
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {
stopped := false
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete {
var req api.DeleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name == "test-model" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
}
return
}
if r.URL.Path == "/api/generate" && r.Method == http.MethodPost {
var req api.GenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Model == "test-model" {
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{
Done: true,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
stopped = true
return
} else {
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{
Done: false,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err)
}
if !stopped {
t.Fatal("Model was not stopped before deletion")
}
err := DeleteHandler(cmd, []string{"test-model-not-found"})
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
modelfileName string
fileExists bool
expectedName string
expectedErr error
}{
{
name: "no modelfile specified, no modelfile exists",
modelfileName: "",
fileExists: false,
expectedName: "",
expectedErr: os.ErrNotExist,
},
{
name: "no modelfile specified, modelfile exists",
modelfileName: "",
fileExists: true,
expectedName: "Modelfile",
expectedErr: nil,
},
{
name: "modelfile specified, no modelfile exists",
modelfileName: "crazyfile",
fileExists: false,
expectedName: "crazyfile",
expectedErr: os.ErrNotExist,
},
{
name: "modelfile specified, modelfile exists",
modelfileName: "anotherfile",
fileExists: true,
expectedName: "anotherfile",
expectedErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{
Use: "fakecmd",
}
cmd.Flags().String("file", "", "path to modelfile")
var expectedFilename string
if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
} else {
fn = "Modelfile"
}
tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
} else {
if tt.modelfileName != "" {
expectedFilename = tt.modelfileName
err := cmd.Flags().Set("file", tt.modelfileName)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
}
}
actualFilename, actualErr := getModelfileName(cmd)
if actualFilename != expectedFilename {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
}
if tt.expectedErr != os.ErrNotExist {
if actualErr != tt.expectedErr {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
} else {
if !os.IsNotExist(actualErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
}
})
}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful push",
modelName: "test-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
var req api.PushRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
// Simulate progress updates
responses := []api.ProgressResponse{
{Status: "preparing manifest"},
{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)
return
}
http.Error(w, "not found", http.StatusNotFound)
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err := PushHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
) )
...@@ -31,26 +30,6 @@ const ( ...@@ -31,26 +30,6 @@ const (
MultilineSystem MultilineSystem
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error {
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
chatReq := &api.ChatRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
}
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
continue continue
...@@ -340,8 +319,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -340,8 +319,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, newMessage) opts.Messages = append(opts.Messages, newMessage)
} }
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
sb.Reset() sb.Reset()
continue continue
default: default:
...@@ -371,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -371,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] { switch args[1] {
case "info": case "info":
showInfo(resp) _ = showInfo(resp, os.Stderr)
case "license": case "license":
if resp.License == "" { if resp.License == "" {
fmt.Println("No license was specified for this model.") fmt.Println("No license was specified for this model.")
...@@ -463,13 +440,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -463,13 +440,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
// clear all previous images for better responses
if len(images) > 0 {
for i := range opts.Messages {
opts.Messages[i].Images = nil
}
}
newMessage.Content = msg newMessage.Content = msg
newMessage.Images = images newMessage.Images = images
} }
...@@ -522,35 +492,29 @@ func buildModelfile(opts runOptions) string { ...@@ -522,35 +492,29 @@ func buildModelfile(opts runOptions) string {
} }
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements return strings.NewReplacer(
replacements := map[string]string{ "\\ ", " ", // Escaped space
"\\ ": " ", // Escaped space "\\(", "(", // Escaped left parenthesis
"\\(": "(", // Escaped left parenthesis "\\)", ")", // Escaped right parenthesis
"\\)": ")", // Escaped right parenthesis "\\[", "[", // Escaped left square bracket
"\\[": "[", // Escaped left square bracket "\\]", "]", // Escaped right square bracket
"\\]": "]", // Escaped right square bracket "\\{", "{", // Escaped left curly brace
"\\{": "{", // Escaped left curly brace "\\}", "}", // Escaped right curly brace
"\\}": "}", // Escaped right curly brace "\\$", "$", // Escaped dollar sign
"\\$": "$", // Escaped dollar sign "\\&", "&", // Escaped ampersand
"\\&": "&", // Escaped ampersand "\\;", ";", // Escaped semicolon
"\\;": ";", // Escaped semicolon "\\'", "'", // Escaped single quote
"\\'": "'", // Escaped single quote "\\\\", "\\", // Escaped backslash
"\\\\": "\\", // Escaped backslash "\\*", "*", // Escaped asterisk
"\\*": "*", // Escaped asterisk "\\?", "?", // Escaped question mark
"\\?": "?", // Escaped question mark ).Replace(fp)
}
for escaped, actual := range replacements {
fp = strings.ReplaceAll(fp, escaped, actual)
}
return fp
} }
func extractFileNames(input string) []string { func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
...@@ -563,10 +527,9 @@ func extractFileData(input string) (string, []api.ImageData, error) { ...@@ -563,10 +527,9 @@ func extractFileData(input string) (string, []api.ImageData, error) {
for _, fp := range filePaths { for _, fp := range filePaths {
nfp := normalizeFilePath(fp) nfp := normalizeFilePath(fp)
data, err := getImageData(nfp) data, err := getImageData(nfp)
if err != nil { if errors.Is(err, os.ErrNotExist) {
if os.IsNotExist(err) { continue
continue } else if err != nil {
}
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
return "", imgs, err return "", imgs, err
} }
...@@ -574,7 +537,7 @@ func extractFileData(input string) (string, []api.ImageData, error) { ...@@ -574,7 +537,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs = append(imgs, data)
} }
return input, imgs, nil return strings.TrimSpace(input), imgs, nil
} }
func getImageData(filePath string) ([]byte, error) { func getImageData(filePath string) ([]byte, error) {
......
...@@ -12,44 +12,45 @@ import ( ...@@ -12,44 +12,45 @@ import (
func TestExtractFilenames(t *testing.T) { func TestExtractFilenames(t *testing.T) {
// Unix style paths // Unix style paths
input := ` some preamble input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg` /unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input) res := extractFileNames(input)
assert.Len(t, res, 5) assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg") assert.Contains(t, res[4], "five.JPG")
assert.NotContains(t, res[4], '"') assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbtween") assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
// Windows style paths // Windows style paths
input = ` some preamble input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2 c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6 ./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8 d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
` `
res = extractFileNames(input) res = extractFileNames(input)
assert.Len(t, res, 10) assert.Len(t, res, 10)
assert.NotContains(t, res, "inbtween") assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:") assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:") assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg") assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.png") assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.svg") assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "d:") assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png") assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:") assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png") assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:") assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
} }
......
...@@ -7,16 +7,27 @@ import ( ...@@ -7,16 +7,27 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"strings"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type Parameters struct { type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
} }
func (Parameters) KV(t *Tokenizer) llm.KV { type AdapterParameters struct {
Alpha uint32 `json:"lora_alpha"`
LoraLayers uint32 `json:"lora_layers"`
LoraParameters struct {
Rank uint32 `json:"rank"`
Alpha float32 `json:"alpha"`
Scale float32 `json:"scale"`
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) llm.KV {
kv := llm.KV{ kv := llm.KV{
"general.file_type": uint32(1), "general.file_type": uint32(1),
"general.quantization_version": uint32(2), "general.quantization_version": uint32(2),
...@@ -27,6 +38,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV { ...@@ -27,6 +38,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
"tokenizer.ggml.token_type": t.Vocabulary.Types, "tokenizer.ggml.token_type": t.Vocabulary.Types,
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
if t.Template != "" { if t.Template != "" {
kv["tokenizer.chat_template"] = t.Template kv["tokenizer.chat_template"] = t.Template
} }
...@@ -39,40 +54,119 @@ func (Parameters) KV(t *Tokenizer) llm.KV { ...@@ -39,40 +54,119 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
return kv return kv
} }
func (Parameters) specialTokenTypes() []string { func (p AdapterParameters) KV() llm.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
} else {
alpha = p.LoraParameters.Alpha
}
kv := llm.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
"general.type": "adapter",
"general.version": "v0.2",
}
return kv
}
func (ModelParameters) specialTokenTypes() []string {
return []string{ return []string{
"bos", "eos", "unk", "sep", "pad", "cls", "mask", "bos", "eos", "unk", "sep", "pad", "cls", "mask",
} }
} }
func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error { func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts) return llm.WriteGGUF(ws, kv, ts)
} }
type Converter interface { func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(*Tokenizer) llm.KV KV(*Tokenizer) llm.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor Tensors([]Tensor) []llm.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
// tensorName returns the LLM tensor name for a specific input name
tensorName(string) string
// specialTokenTypes returns any special token types the model uses // specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
} }
type moreParser interface {
parseMore(fs.FS) error
}
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(llm.KV) llm.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
}
var p AdapterParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
var conv AdapterConverter
switch arch {
case "llama":
conv = &llamaAdapter{}
case "gemma2":
conv = &gemma2Adapter{}
default:
return errors.New("unsupported architecture")
}
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
return err
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
}
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path. // and files it finds in the input path.
// Supported input model formats include safetensors. // Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func Convert(fsys fs.FS, ws io.WriteSeeker) error { func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json") bts, err := fs.ReadFile(fsys, "config.json")
if err != nil { if err != nil {
return err return err
} }
var p Parameters var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil { if err := json.Unmarshal(bts, &p); err != nil {
return err return err
} }
...@@ -81,14 +175,20 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error { ...@@ -81,14 +175,20 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
return errors.New("unknown architecture") return errors.New("unknown architecture")
} }
var conv Converter var conv ModelConverter
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM": case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llama{} conv = &llamaModel{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtral{} conv = &mixtralModel{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
conv = &gemma{} conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "BertModel":
conv = &bertModel{}
default: default:
return errors.New("unsupported architecture") return errors.New("unsupported architecture")
} }
...@@ -97,23 +197,33 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error { ...@@ -97,23 +197,33 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
return err return err
} }
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes()) t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil { if err != nil {
return err return err
} }
if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) { vocabSize := int(p.VocabSize)
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens)) switch {
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
} }
} else { case vocabSize < len(t.Vocabulary.Tokens):
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
} }
ts, err := parseTensors(fsys) ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil { if err != nil {
return err return err
} }
......
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/llm"
)
type bertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayer uint32 `json:"n_layer"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NCtx uint32 `json:"n_ctx"`
HiddenSize uint32 `json:"hidden_size"`
NEmbd uint32 `json:"n_embd"`
IntermediateSize uint32 `json:"intermediate_size"`
NInner uint32 `json:"n_inner"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NHead uint32 `json:"n_head"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
PoolingType uint32
}
var (
_ ModelConverter = (*bertModel)(nil)
_ moreParser = (*bertModel)(nil)
)
func (p *bertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" {
pooling = m.Path
break
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *bertModel) KV(t *Tokenizer) llm.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
kv["bert.context_length"] = contextLength
}
if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
kv["bert.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
}
if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
kv["bert.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
}
if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
kv["bert.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
kv["bert.attention.layer_norm_epsilon"] = layerNormEpsilon
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
if strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]") {
// noop
} else if strings.HasPrefix(e, "##") {
t.Tokens[i] = e[2:]
} else {
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (bertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"embeddings.position_embeddings", "position_embd",
"attention.self.query", "attn_q",
"attention.self.key", "attn_k",
"attention.self.value", "attn_v",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type gemma struct { type gemmaModel struct {
Parameters ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"` HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"` HiddenLayers uint32 `json:"num_hidden_layers"`
...@@ -21,12 +21,11 @@ type gemma struct { ...@@ -21,12 +21,11 @@ type gemma struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
} }
var _ Converter = (*gemma)(nil) var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemma) KV(t *Tokenizer) llm.KV { func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma" kv["general.architecture"] = "gemma"
kv["general.name"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings kv["gemma.context_length"] = p.MaxPositionEmbeddings
kv["gemma.embedding_length"] = p.HiddenSize kv["gemma.embedding_length"] = p.HiddenSize
kv["gemma.block_count"] = p.HiddenLayers kv["gemma.block_count"] = p.HiddenLayers
...@@ -43,16 +42,15 @@ func (p *gemma) KV(t *Tokenizer) llm.KV { ...@@ -43,16 +42,15 @@ func (p *gemma) KV(t *Tokenizer) llm.KV {
return kv return kv
} }
func (p *gemma) Tensors(ts []Tensor) []llm.Tensor { func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor var out []llm.Tensor
for _, t := range ts { for _, t := range ts {
name := p.tensorName(t.Name()) if strings.HasSuffix(t.Name(), "_norm.weight") {
if strings.HasSuffix(name, "_norm.weight") {
t.SetRepacker(p.addOne) t.SetRepacker(p.addOne)
} }
out = append(out, llm.Tensor{ out = append(out, llm.Tensor{
Name: name, Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
...@@ -62,8 +60,8 @@ func (p *gemma) Tensors(ts []Tensor) []llm.Tensor { ...@@ -62,8 +60,8 @@ func (p *gemma) Tensors(ts []Tensor) []llm.Tensor {
return out return out
} }
func (p *gemma) tensorName(n string) string { func (p *gemmaModel) Replacements() []string {
return strings.NewReplacer( return []string{
"model.embed_tokens", "token_embd", "model.embed_tokens", "token_embd",
"model.norm", "output_norm", "model.norm", "output_norm",
"model.layers", "blk", "model.layers", "blk",
...@@ -76,11 +74,10 @@ func (p *gemma) tensorName(n string) string { ...@@ -76,11 +74,10 @@ func (p *gemma) tensorName(n string) string {
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm", "post_attention_layernorm", "ffn_norm",
"block_sparse_moe.gate", "ffn_inp", }
).Replace(n)
} }
func (*gemma) addOne(_ string, data []float32, shape []uint64) ([]float32, error) { func (*gemmaModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data)) n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0])) ones := tensor.Ones(tensor.Float32, int(shape[0]))
......
package convert
import (
"github.com/ollama/ollama/llm"
)
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
AttentionLogitSoftcap float32 `json:"attn_logit_softcapping"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
kv["gemma2.embedding_length"] = p.HiddenSize
kv["gemma2.block_count"] = p.HiddenLayers
kv["gemma2.feed_forward_length"] = p.IntermediateSize
kv["gemma2.attention.head_count"] = p.NumAttentionHeads
kv["gemma2.attention.head_count_kv"] = p.NumKeyValueHeads
kv["gemma2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma2.attention.key_length"] = p.HeadDim
kv["gemma2.attention.value_length"] = p.HeadDim
kv["gemma2.attention.sliding_window"] = p.SlidingWindow
kv["gemma2.attn_logit_softcapping"] = p.AttentionLogitSoftcap
kv["gemma2.final_logit_softcapping"] = p.FinalLogitSoftcap
kv["tokenizer.ggml.eot_token_id"] = uint32(107)
kv["tokenizer.ggml.middle_token_id"] = uint32(68)
kv["tokenizer.ggml.prefix_token_id"] = uint32(67)
kv["tokenizer.ggml.suffix_token_id"] = uint32(69)
return kv
}
func (p *gemma2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
}
}
package convert
import (
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type gemma2Adapter struct {
AdapterParameters
}
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
shape[0], shape[1] = shape[1], shape[0]
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *gemma2Adapter) Replacements() []string {
return []string{
"base_model.model.", "",
"model.layers", "blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"lora_A.weight", "weight.lora_a",
"lora_B.weight", "weight.lora_b",
"lora_a", "weight.lora_a",
"lora_b", "weight.lora_b",
}
}
func (p *gemma2Adapter) repack(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.T(1, 0); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
...@@ -3,6 +3,7 @@ package convert ...@@ -3,6 +3,7 @@ package convert
import ( import (
"cmp" "cmp"
"fmt" "fmt"
"math"
"strings" "strings"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
...@@ -11,8 +12,8 @@ import ( ...@@ -11,8 +12,8 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type llama struct { type llamaModel struct {
Parameters ModelParameters
NLayers uint32 `json:"n_layers"` NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"` NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayer uint32 `json:"n_layer"` NLayer uint32 `json:"n_layer"`
...@@ -27,8 +28,14 @@ type llama struct { ...@@ -27,8 +28,14 @@ type llama struct {
NumKeyValueHeads uint32 `json:"num_key_value_heads"` NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
RopeScaling struct { RopeScaling struct {
Type string `json:"type"` Type string `json:"type"`
Factor float32 `json:"factor"` RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"` } `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"` RMSNormEPS float32 `json:"rms_norm_eps"`
LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEPS float32 `json:"layer_norm_eps"`
...@@ -37,12 +44,11 @@ type llama struct { ...@@ -37,12 +44,11 @@ type llama struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
} }
var _ Converter = (*llama)(nil) var _ ModelConverter = (*llamaModel)(nil)
func (p *llama) KV(t *Tokenizer) llm.KV { func (p *llamaModel) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama" kv["general.architecture"] = "llama"
kv["general.name"] = "llama"
kv["llama.vocab_size"] = p.VocabSize kv["llama.vocab_size"] = p.VocabSize
kv["llama.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) kv["llama.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
...@@ -71,6 +77,27 @@ func (p *llama) KV(t *Tokenizer) llm.KV { ...@@ -71,6 +77,27 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
if p.RopeScaling.Type == "linear" { if p.RopeScaling.Type == "linear" {
kv["llama.rope.scaling.type"] = p.RopeScaling.Type kv["llama.rope.scaling.type"] = p.RopeScaling.Type
kv["llama.rope.scaling.factor"] = p.RopeScaling.Factor kv["llama.rope.scaling.factor"] = p.RopeScaling.Factor
} else if p.RopeScaling.RopeType == "llama3" {
dim := p.HiddenSize / p.NumAttentionHeads
for i := uint32(0); i < dim; i += 2 {
factor := cmp.Or(p.RopeScaling.Factor, 8.0)
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
lambda := 2 * math.Pi * math.Pow(float64(p.RopeTheta), float64(i)/float64(dim))
if lambda < float64(lambdaHigh) {
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
} else if lambda > float64(lambdaLow) {
p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
} else {
smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
}
}
} }
if p.NumKeyValueHeads > 0 { if p.NumKeyValueHeads > 0 {
...@@ -90,24 +117,29 @@ func (p *llama) KV(t *Tokenizer) llm.KV { ...@@ -90,24 +117,29 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
kv["llama.attention.value_length"] = p.HeadDim kv["llama.attention.value_length"] = p.HeadDim
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
return kv return kv
} }
func (p *llama) Tensors(ts []Tensor) []llm.Tensor { func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor var out []llm.Tensor
if p.RopeScaling.factors != nil {
out = append(out, llm.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
WriterTo: p.RopeScaling.factors,
})
}
for _, t := range ts { for _, t := range ts {
name := p.tensorName(t.Name()) if strings.HasSuffix(t.Name(), "attn_q.weight") ||
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
strings.HasSuffix(name, "attn_k.weight") {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, llm.Tensor{ out = append(out, llm.Tensor{
Name: name, Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
...@@ -117,8 +149,8 @@ func (p *llama) Tensors(ts []Tensor) []llm.Tensor { ...@@ -117,8 +149,8 @@ func (p *llama) Tensors(ts []Tensor) []llm.Tensor {
return out return out
} }
func (p *llama) tensorName(n string) string { func (p *llamaModel) Replacements() []string {
return strings.NewReplacer( return []string{
"lm_head", "output", "lm_head", "output",
"model.embed_tokens", "token_embd", "model.embed_tokens", "token_embd",
"model.norm", "output_norm", "model.norm", "output_norm",
...@@ -132,21 +164,19 @@ func (p *llama) tensorName(n string) string { ...@@ -132,21 +164,19 @@ func (p *llama) tensorName(n string) string {
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm", "post_attention_layernorm", "ffn_norm",
// mixtral }
"block_sparse_moe.gate", "ffn_gate_inp",
).Replace(n)
} }
func (p *llama) repack(name string, data []float32, shape []uint64) ([]float32, error) { func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int var dims []int
for _, dim := range shape { for _, dim := range shape {
dims = append(dims, int(dim)) dims = append(dims, int(dim))
} }
var heads uint32 var heads uint32
if strings.HasSuffix(name, "q_proj.weight") { if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "k_proj.weight") { } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else { } else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name) return nil, fmt.Errorf("unknown tensor for repack: %s", name)
......
package convert
import (
"cmp"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type llamaAdapter struct {
AdapterParameters
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
}
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}
func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
shape[0], shape[1] = shape[1], shape[0]
t.SetRepacker(p.repackAndTranspose)
} else {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
return out
}
func (p *llamaAdapter) Replacements() []string {
return []string{
"base_model.model.", "",
"model.layers", "blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"lora_A.weight", "weight.lora_a",
"lora_B.weight", "weight.lora_b",
"lora_a", "weight.lora_a",
"lora_b", "weight.lora_b",
}
}
func (p *llamaAdapter) repack(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight.lora_a") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight.lora_a") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return data, nil
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
func (p *llamaAdapter) repackAndTranspose(name string, data []float32, shape []uint64) ([]float32, error) {
dims := []int{int(shape[1]), int(shape[0])}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var heads uint32
if strings.HasSuffix(name, "attn_q.weight.lora_a") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight.lora_a") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
}
if heads > 0 {
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
}
if err := n.T(1, 0); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment