Commit 0ce8bcfd authored by xuxzh1's avatar xuxzh1 🎱
Browse files

init

parent b0135f4b
...@@ -88,10 +88,15 @@ DialogFontSize=12 ...@@ -88,10 +88,15 @@ DialogFontSize=12
[Files] [Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
#if DirExists("..\dist\windows-amd64\cuda")
Source: "..\dist\windows-amd64\cuda\*"; DestDir: "{app}\cuda\"; Flags: ignoreversion recursesubdirs
#endif
#if DirExists("..\dist\windows-amd64\oneapi")
Source: "..\dist\windows-amd64\oneapi\*"; DestDir: "{app}\oneapi\"; Flags: ignoreversion recursesubdirs
#endif
#if DirExists("..\dist\windows-amd64\rocm") #if DirExists("..\dist\windows-amd64\rocm")
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
#endif #endif
...@@ -122,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models" ...@@ -122,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages] [Messages]
WizardReady=Ollama Windows Preview WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models. ReadyLabel1=%nLet's get you up and running with your own large language models.
...@@ -129,7 +138,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi ...@@ -129,7 +138,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
;FinishedHeadingLabel=Run your first model ;FinishedHeadingLabel=Run your first model
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3 ;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3.1
;ClickFinish=%n ;ClickFinish=%n
[Registry] [Registry]
......
...@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!" ...@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
write-host "" write-host ""
write-host "Run your first model:" write-host "Run your first model:"
write-host "" write-host ""
write-host "`tollama run llama3" write-host "`tollama run llama3.1"
write-host "" write-host ""
\ No newline at end of file
...@@ -29,7 +29,6 @@ func GetID() string { ...@@ -29,7 +29,6 @@ func GetID() string {
initStore() initStore()
} }
return store.ID return store.ID
} }
func GetFirstTimeRun() bool { func GetFirstTimeRun() bool {
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
package tray package tray
import ( import (
"fmt" "errors"
"github.com/ollama/ollama/app/tray/commontray" "github.com/ollama/ollama/app/tray/commontray"
) )
func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) { func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) {
return nil, fmt.Errorf("NOT IMPLEMENTED YET") return nil, errors.New("not implemented")
} }
...@@ -11,9 +11,7 @@ import ( ...@@ -11,9 +11,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
var ( var quitOnce sync.Once
quitOnce sync.Once
)
func (t *winTray) Run() { func (t *winTray) Run() {
nativeLoop() nativeLoop()
...@@ -47,7 +45,6 @@ func nativeLoop() { ...@@ -47,7 +45,6 @@ func nativeLoop() {
default: default:
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
} }
} }
} }
...@@ -160,8 +157,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui ...@@ -160,8 +157,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
lResult, _, _ = pDefWindowProc.Call( lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd), uintptr(hWnd),
uintptr(message), uintptr(message),
uintptr(wParam), wParam,
uintptr(lParam), lParam,
) )
} }
return return
......
...@@ -13,8 +13,9 @@ import ( ...@@ -13,8 +13,9 @@ import (
"sync" "sync"
"unsafe" "unsafe"
"github.com/ollama/ollama/app/tray/commontray"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/ollama/ollama/app/tray/commontray"
) )
// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32 // Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32
...@@ -186,7 +187,7 @@ func (t *winTray) initInstance() error { ...@@ -186,7 +187,7 @@ func (t *winTray) initInstance() error {
t.muNID.Lock() t.muNID.Lock()
defer t.muNID.Unlock() defer t.muNID.Unlock()
t.nid = &notifyIconData{ t.nid = &notifyIconData{
Wnd: windows.Handle(t.window), Wnd: t.window,
ID: 100, ID: 100,
Flags: NIF_MESSAGE, Flags: NIF_MESSAGE,
CallbackMessage: t.wmSystrayMessage, CallbackMessage: t.wmSystrayMessage,
...@@ -197,7 +198,6 @@ func (t *winTray) initInstance() error { ...@@ -197,7 +198,6 @@ func (t *winTray) initInstance() error {
} }
func (t *winTray) createMenu() error { func (t *winTray) createMenu() error {
menuHandle, _, err := pCreatePopupMenu.Call() menuHandle, _, err := pCreatePopupMenu.Call()
if menuHandle == 0 { if menuHandle == 0 {
return err return err
...@@ -246,7 +246,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title ...@@ -246,7 +246,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
mi := menuItemInfo{ mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE, Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
Type: MFT_STRING, Type: MFT_STRING,
ID: uint32(menuItemId), ID: menuItemId,
TypeData: titlePtr, TypeData: titlePtr,
Cch: uint32(len(title)), Cch: uint32(len(title)),
} }
...@@ -302,11 +302,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title ...@@ -302,11 +302,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
} }
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
mi := menuItemInfo{ mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE, Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
Type: MFT_SEPARATOR, Type: MFT_SEPARATOR,
ID: uint32(menuItemId), ID: menuItemId,
} }
mi.Size = uint32(unsafe.Sizeof(mi)) mi.Size = uint32(unsafe.Sizeof(mi))
...@@ -416,7 +415,7 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) { ...@@ -416,7 +415,7 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
iconFilePath := filepath.Join(os.TempDir(), "ollama_temp_icon_"+dataHash) iconFilePath := filepath.Join(os.TempDir(), "ollama_temp_icon_"+dataHash)
if _, err := os.Stat(iconFilePath); os.IsNotExist(err) { if _, err := os.Stat(iconFilePath); os.IsNotExist(err) {
if err := os.WriteFile(iconFilePath, iconBytes, 0644); err != nil { if err := os.WriteFile(iconFilePath, iconBytes, 0o644); err != nil {
return "", err return "", err
} }
} }
...@@ -426,7 +425,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) { ...@@ -426,7 +425,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
// Loads an image from file and shows it in tray. // Loads an image from file and shows it in tray.
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx // Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
func (t *winTray) setIcon(src string) error { func (t *winTray) setIcon(src string) error {
h, err := t.loadIconFrom(src) h, err := t.loadIconFrom(src)
if err != nil { if err != nil {
return err return err
...@@ -444,7 +442,6 @@ func (t *winTray) setIcon(src string) error { ...@@ -444,7 +442,6 @@ func (t *winTray) setIcon(src string) error {
// Loads an image from file to be shown in tray or menu item. // Loads an image from file to be shown in tray or menu item.
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx // LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) { func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
// Save and reuse handles of loaded images // Save and reuse handles of loaded images
t.muLoadedImages.RLock() t.muLoadedImages.RLock()
h, ok := t.loadedImages[src] h, ok := t.loadedImages[src]
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
...@@ -78,7 +79,7 @@ func Sign(ctx context.Context, bts []byte) (string, error) { ...@@ -78,7 +79,7 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
parts := bytes.Split(publicKey, []byte(" ")) parts := bytes.Split(publicKey, []byte(" "))
if len(parts) < 2 { if len(parts) < 2 {
return "", fmt.Errorf("malformed public key") return "", errors.New("malformed public key")
} }
signedData, err := privateKey.Sign(rand.Reader, bts) signedData, err := privateKey.Sign(rand.Reader, bts)
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"runtime" "runtime"
"slices"
"strings" "strings"
"syscall" "syscall"
"time" "time"
...@@ -29,7 +30,6 @@ import ( ...@@ -29,7 +30,6 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
...@@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) { ...@@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) {
} }
defer tempfile.Close() defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
detectContentType := func(path string) (string, error) { detectContentType := func(path string) (string, error) {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
...@@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) { ...@@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) {
files = append(files, tks...) files = append(files, tks...)
} }
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
for _, file := range files { for _, file := range files {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
...@@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er ...@@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
} }
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
name := args[0]
// check if the model exists on the server
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, []string{name}); err != nil {
return err
}
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
if err != nil {
return err
}
case err != nil:
return err
}
interactive := true interactive := true
opts := runOptions{ opts := runOptions{
Model: args[0], Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color", WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{}, Options: map[string]interface{}{},
MultiModal: slices.Contains(show.Details.Families, "clip"),
ParentModel: show.Details.ParentModel,
} }
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
...@@ -362,11 +336,53 @@ func RunHandler(cmd *cobra.Command, args []string) error { ...@@ -362,11 +336,53 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.WordWrap = !nowrap opts.WordWrap = !nowrap
if !interactive { // Fill out the rest of the options based on information about the
return generate(cmd, opts) // model.
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
}()
if err != nil {
return err
} }
return generateInteractive(cmd, opts) opts.MultiModal = slices.Contains(info.Details.Families, "clip")
opts.ParentModel = info.Details.ParentModel
if interactive {
if err := loadModel(cmd, &opts); err != nil {
return err
}
for _, msg := range info.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
} }
func errFromUnknownKey(unknownKeyErr error) error { func errFromUnknownKey(unknownKeyErr error) error {
...@@ -579,10 +595,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error { ...@@ -579,10 +595,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
if len(args) != 1 {
return errors.New("missing model name")
}
license, errLicense := cmd.Flags().GetBool("license") license, errLicense := cmd.Flags().GetBool("license")
modelfile, errModelfile := cmd.Flags().GetBool("modelfile") modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
parameters, errParams := cmd.Flags().GetBool("parameters") parameters, errParams := cmd.Flags().GetBool("parameters")
...@@ -625,8 +637,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error { ...@@ -625,8 +637,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
if flagsSet > 1 { if flagsSet > 1 {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified") return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
} else if flagsSet == 0 {
return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified")
} }
req := api.ShowRequest{Name: args[0]} req := api.ShowRequest{Name: args[0]}
...@@ -635,22 +645,141 @@ func ShowHandler(cmd *cobra.Command, args []string) error { ...@@ -635,22 +645,141 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
switch showType { if flagsSet == 1 {
case "license": switch showType {
fmt.Println(resp.License) case "license":
case "modelfile": fmt.Println(resp.License)
fmt.Println(resp.Modelfile) case "modelfile":
case "parameters": fmt.Println(resp.Modelfile)
fmt.Println(resp.Parameters) case "parameters":
case "system": fmt.Println(resp.Parameters)
fmt.Println(resp.System) case "system":
case "template": fmt.Println(resp.System)
fmt.Println(resp.Template) case "template":
fmt.Println(resp.Template)
}
return nil
} }
showInfo(resp)
return nil return nil
} }
func showInfo(resp *api.ShowResponse) {
arch := resp.ModelInfo["general.architecture"].(string)
modelData := [][]string{
{"arch", arch},
{"parameters", resp.Details.ParameterSize},
{"quantization", resp.Details.QuantizationLevel},
{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
}
mainTableData := [][]string{
{"Model"},
{renderSubTable(modelData, false)},
}
if resp.ProjectorInfo != nil {
projectorData := [][]string{
{"arch", "clip"},
{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
}
if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
}
projectorData = append(projectorData,
[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
)
mainTableData = append(mainTableData,
[]string{"Projector"},
[]string{renderSubTable(projectorData, false)},
)
}
if resp.Parameters != "" {
mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)})
}
if resp.System != "" {
mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)})
}
if resp.License != "" {
mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)})
}
table := tablewriter.NewWriter(os.Stdout)
table.SetAutoWrapText(false)
table.SetBorder(false)
table.SetAlignment(tablewriter.ALIGN_LEFT)
for _, v := range mainTableData {
table.Append(v)
}
table.Render()
}
func renderSubTable(data [][]string, file bool) string {
var buf bytes.Buffer
table := tablewriter.NewWriter(&buf)
table.SetAutoWrapText(!file)
table.SetBorder(false)
table.SetNoWhiteSpace(true)
table.SetTablePadding("\t")
table.SetAlignment(tablewriter.ALIGN_LEFT)
for _, v := range data {
table.Append(v)
}
table.Render()
renderedTable := buf.String()
lines := strings.Split(renderedTable, "\n")
for i, line := range lines {
lines[i] = "\t" + line
}
return strings.Join(lines, "\n")
}
func twoLines(s string) [][]string {
lines := strings.Split(s, "\n")
res := [][]string{}
count := 0
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" {
count++
res = append(res, []string{line})
if count == 2 {
return res
}
}
}
return res
}
func formatParams(s string) string {
lines := strings.Split(s, "\n")
table := [][]string{}
for _, line := range lines {
table = append(table, strings.Fields(line))
}
return renderSubTable(table, false)
}
func CopyHandler(cmd *cobra.Command, args []string) error { func CopyHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
...@@ -729,7 +858,6 @@ type runOptions struct { ...@@ -729,7 +858,6 @@ type runOptions struct {
WordWrap bool WordWrap bool
Format string Format string
System string System string
Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool MultiModal bool
...@@ -746,7 +874,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState) ...@@ -746,7 +874,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
if wordWrap && termWidth >= 10 { if wordWrap && termWidth >= 10 {
for _, ch := range content { for _, ch := range content {
if state.lineLength+1 > termWidth-5 { if state.lineLength+1 > termWidth-5 {
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 { if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch) fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = "" state.wordBuffer = ""
...@@ -924,7 +1051,6 @@ func generate(cmd *cobra.Command, opts runOptions) error { ...@@ -924,7 +1051,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images, Images: opts.Images,
Format: opts.Format, Format: opts.Format,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
} }
...@@ -961,17 +1087,11 @@ func generate(cmd *cobra.Command, opts runOptions) error { ...@@ -961,17 +1087,11 @@ func generate(cmd *cobra.Command, opts runOptions) error {
} }
func RunServer(cmd *cobra.Command, _ []string) error { func RunServer(cmd *cobra.Command, _ []string) error {
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil {
return err
}
if err := initializeKeypair(); err != nil { if err := initializeKeypair(); err != nil {
return err return err
} }
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port)) ln, err := net.Listen("tcp", envconfig.Host().Host)
if err != nil { if err != nil {
return err return err
} }
...@@ -1030,24 +1150,6 @@ func initializeKeypair() error { ...@@ -1030,24 +1150,6 @@ func initializeKeypair() error {
return nil return nil
} }
//nolint:unused
func waitForServer(ctx context.Context, client *api.Client) error {
// wait for the server to start
timeout := time.After(5 * time.Second)
tick := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
...@@ -1058,7 +1160,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { ...@@ -1058,7 +1160,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err return err
} }
if err := startApp(cmd.Context(), client); err != nil { if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("could not connect to ollama app, is it running?") return errors.New("could not connect to ollama app, is it running?")
} }
} }
return nil return nil
...@@ -1254,10 +1356,10 @@ func NewCLI() *cobra.Command { ...@@ -1254,10 +1356,10 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NUM_PARALLEL"], envVars["OLLAMA_NUM_PARALLEL"],
envVars["OLLAMA_NOPRUNE"], envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"], envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"], envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
}) })
default: default:
appendEnvDocs(cmd, envs) appendEnvDocs(cmd, envs)
......
package cmd package cmd
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -8,14 +9,15 @@ import ( ...@@ -8,14 +9,15 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"sort" "slices"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/maps"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
...@@ -27,74 +29,29 @@ const ( ...@@ -27,74 +29,29 @@ const (
MultilineNone MultilineState = iota MultilineNone MultilineState = iota
MultilinePrompt MultilinePrompt
MultilineSystem MultilineSystem
MultilineTemplate
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error { func loadModel(cmd *cobra.Command, opts *runOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.StopAndClear() defer p.StopAndClear()
spinner := progress.NewSpinner("") spinner := progress.NewSpinner("")
p.Add("", spinner) p.Add("", spinner)
showReq := api.ShowRequest{Name: opts.Model} client, err := api.ClientFromEnvironment()
showResp, err := client.Show(cmd.Context(), &showReq)
if err != nil { if err != nil {
return err return err
} }
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
opts.ParentModel = showResp.Details.ParentModel
if len(showResp.Messages) > 0 {
opts.Messages = append(opts.Messages, showResp.Messages...)
}
chatReq := &api.ChatRequest{ chatReq := &api.ChatRequest{
Model: opts.Model, Model: opts.Model,
Messages: []api.Message{}, KeepAlive: opts.KeepAlive,
} }
if opts.KeepAlive != nil { return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
chatReq.KeepAlive = opts.KeepAlive
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear()
if len(opts.Messages) > 0 {
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
}
return nil
})
if err != nil {
return err
}
return nil
} }
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = make([]api.Message, 0)
err := loadModel(cmd, &opts)
if err != nil {
return err
}
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables") fmt.Fprintln(os.Stderr, " /set Set session variables")
...@@ -119,7 +76,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -119,7 +76,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message") fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
...@@ -165,6 +121,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -165,6 +121,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict") fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens") fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities") fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter min_p <float> Pick token based on top token probability * min_p")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size") fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level") fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions") fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
...@@ -184,7 +141,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -184,7 +141,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
if envconfig.NoHistory { if envconfig.NoHistory() {
scanner.HistoryDisable() scanner.HistoryDisable()
} }
...@@ -229,10 +186,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -229,10 +186,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
multiline = MultilineNone multiline = MultilineNone
...@@ -351,17 +304,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -351,17 +304,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", ")) fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]] opts.Options[args[2]] = fp[args[2]]
case "system", "template": case "system":
if len(args) < 3 { if len(args) < 3 {
usageSet() usageSet()
continue continue
} }
if args[1] == "system" { multiline = MultilineSystem
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ") line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`) line, ok := strings.CutPrefix(line, `"""`)
...@@ -381,23 +330,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -381,23 +330,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue continue
} }
if args[1] == "system" { opts.System = sb.String() // for display in modelfile
opts.System = sb.String() // for display in modelfile newMessage := api.Message{Role: "system", Content: sb.String()}
newMessage := api.Message{Role: "system", Content: sb.String()} // Check if the slice is not empty and the last message is from 'system'
// Check if the slice is not empty and the last message is from 'system' if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { // Replace the last message
// Replace the last message opts.Messages[len(opts.Messages)-1] = newMessage
opts.Messages[len(opts.Messages)-1] = newMessage } else {
} else { opts.Messages = append(opts.Messages, newMessage)
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
fmt.Println("Set system message.")
sb.Reset()
sb.Reset() sb.Reset()
continue continue
...@@ -416,10 +359,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -416,10 +359,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
req := &api.ShowRequest{ req := &api.ShowRequest{
Name: opts.Model, Name: opts.Model,
System: opts.System, System: opts.System,
Template: opts.Template, Options: opts.Options,
Options: opts.Options,
} }
resp, err := client.Show(cmd.Context(), req) resp, err := client.Show(cmd.Context(), req)
if err != nil { if err != nil {
...@@ -429,15 +371,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -429,15 +371,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] { switch args[1] {
case "info": case "info":
fmt.Println("Model details:") showInfo(resp)
if len(resp.Details.Families) > 0 {
fmt.Printf("Family %s\n", strings.Join(resp.Details.Families, ", "))
} else if resp.Details.Family != "" {
fmt.Printf("Family %s\n", resp.Details.Family)
}
fmt.Printf("Parameter Size %s\n", resp.Details.ParameterSize)
fmt.Printf("Quantization Level %s\n", resp.Details.QuantizationLevel)
fmt.Println("")
case "license": case "license":
if resp.License == "" { if resp.License == "" {
fmt.Println("No license was specified for this model.") fmt.Println("No license was specified for this model.")
...@@ -470,12 +404,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -470,12 +404,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.") fmt.Println("No system message was specified for this model.")
} }
case "template": case "template":
switch { if resp.Template != "" {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template) fmt.Println(resp.Template)
default: } else {
fmt.Println("No prompt template was specified for this model.") fmt.Println("No prompt template was specified for this model.")
} }
default: default:
...@@ -559,35 +490,35 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -559,35 +490,35 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
func buildModelfile(opts runOptions) string { func buildModelfile(opts runOptions) string {
var mf strings.Builder var f parser.File
model := opts.ParentModel f.Commands = append(f.Commands, parser.Command{Name: "model", Args: cmp.Or(opts.ParentModel, opts.Model)})
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" { if opts.System != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template) f.Commands = append(f.Commands, parser.Command{Name: "system", Args: opts.System})
} }
keys := make([]string, 0) keys := maps.Keys(opts.Options)
for k := range opts.Options { slices.Sort(keys)
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k]) v := opts.Options[k]
var cmds []parser.Command
switch t := v.(type) {
case []string:
for _, s := range t {
cmds = append(cmds, parser.Command{Name: k, Args: s})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", t)})
}
f.Commands = append(f.Commands, cmds...)
} }
fmt.Fprintln(&mf)
for _, msg := range opts.Messages { for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content) f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)})
} }
return mf.String() return f.String()
} }
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
...@@ -673,7 +604,7 @@ func getImageData(filePath string) ([]byte, error) { ...@@ -673,7 +604,7 @@ func getImageData(filePath string) ([]byte, error) {
// Check if the file size exceeds 100MB // Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize { if info.Size() > maxSize {
return nil, fmt.Errorf("file size exceeds maximum limit (100MB)") return nil, errors.New("file size exceeds maximum limit (100MB)")
} }
buf = make([]byte, info.Size()) buf = make([]byte, info.Size())
......
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
...@@ -56,61 +55,53 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8 ...@@ -56,61 +55,53 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
func TestModelfileBuilder(t *testing.T) { func TestModelfileBuilder(t *testing.T) {
opts := runOptions{ opts := runOptions{
Model: "hork", Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things", System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"}, {Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
}, },
Options: map[string]interface{}{}, Options: map[string]any{
"temperature": 0.9,
"seed": 42,
"penalize_newline": false,
"stop": []string{"hi", "there"},
},
} }
opts.Options["temperature"] = 0.9 t.Run("model", func(t *testing.T) {
opts.Options["seed"] = 42 expect := `FROM hork
opts.Options["penalize_newline"] = false SYSTEM You are part horse and part shark, but all hork. Do horklike things
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE user """Hey there hork!""" MESSAGE assistant Yes it is true, I am half horse, half shark.
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err := template.New("").Parse(expectedModelfile) actual := buildModelfile(opts)
assert.Nil(t, err) if diff := cmp.Diff(expect, actual); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
var buf bytes.Buffer }
err = tmpl.Execute(&buf, opts) })
assert.Nil(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark" t.Run("parent model", func(t *testing.T) {
mf = buildModelfile(opts) opts.ParentModel = "horseshark"
expectedModelfile = `FROM {{.ParentModel}} expect := `FROM horseshark
SYSTEM """{{.System}}""" SYSTEM You are part horse and part shark, but all hork. Do horklike things
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE user """Hey there hork!""" MESSAGE assistant Yes it is true, I am half horse, half shark.
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
actual := buildModelfile(opts)
tmpl, err = template.New("").Parse(expectedModelfile) if diff := cmp.Diff(expect, actual); diff != "" {
assert.Nil(t, err) t.Errorf("mismatch (-want +got):\n%s", diff)
}
var parentBuf bytes.Buffer })
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf)
} }
//go:build darwin || windows
package cmd
import (
"context"
"errors"
"time"
"github.com/ollama/ollama/api"
)
func waitForServer(ctx context.Context, client *api.Client) error {
// wait for the server to start
timeout := time.After(5 * time.Second)
tick := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
...@@ -2,7 +2,7 @@ package cmd ...@@ -2,7 +2,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "errors"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
...@@ -20,7 +20,7 @@ func startApp(ctx context.Context, client *api.Client) error { ...@@ -20,7 +20,7 @@ func startApp(ctx context.Context, client *api.Client) error {
return err return err
} }
if !strings.Contains(link, "Ollama.app") { if !strings.Contains(link, "Ollama.app") {
return fmt.Errorf("could not find ollama app") return errors.New("could not find ollama app")
} }
path := strings.Split(link, "Ollama.app") path := strings.Split(link, "Ollama.app")
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil { if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
......
...@@ -4,11 +4,11 @@ package cmd ...@@ -4,11 +4,11 @@ package cmd
import ( import (
"context" "context"
"fmt" "errors"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
func startApp(ctx context.Context, client *api.Client) error { func startApp(ctx context.Context, client *api.Client) error {
return fmt.Errorf("could not connect to ollama server, run 'ollama serve' to start it") return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
} }
...@@ -31,7 +31,7 @@ func startApp(ctx context.Context, client *api.Client) error { ...@@ -31,7 +31,7 @@ func startApp(ctx context.Context, client *api.Client) error {
// Finally look in the path // Finally look in the path
appExe, err = exec.LookPath(AppName) appExe, err = exec.LookPath(AppName)
if err != nil { if err != nil {
return fmt.Errorf("could not locate ollama app") return errors.New("could not locate ollama app")
} }
} }
} }
......
package convert package convert
import ( import (
"cmp"
"encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"os"
"path/filepath"
"slices"
"strings"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
const ( type Parameters struct {
_ int32 = iota Architectures []string `json:"architectures"`
tokenTypeNormal VocabSize uint32 `json:"vocab_size"`
tokenTypeUnknown
tokenTypeControl
tokenTypeUserDefined
tokenTypeUnused
tokenTypeByte
)
type Params struct {
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"` // n_embd
HiddenLayers int `json:"num_hidden_layers"` // n_layer
ContextSize int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"`
RopeFrequencyBase float64 `json:"rope_theta"`
Experts int `json:"num_local_experts"`
ExpertsUsed int `json:"num_experts_per_tok"`
PreTokenizer string
ByteOrder
} }
type ByteOrder interface { func (Parameters) KV(t *Tokenizer) llm.KV {
binary.ByteOrder kv := llm.KV{
binary.AppendByteOrder "general.file_type": uint32(1),
} "general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
"tokenizer.ggml.model": t.Vocabulary.Model,
"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
"tokenizer.ggml.scores": t.Vocabulary.Scores,
"tokenizer.ggml.token_type": t.Vocabulary.Types,
}
type ModelArch interface { if t.Template != "" {
GetTensors() error kv["tokenizer.chat_template"] = t.Template
LoadVocab() error }
WriteGGUF(io.WriteSeeker) error
}
type ModelFormat interface { for _, sv := range t.SpecialVocabulary {
GetLayerName(string) (string, error) kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
GetTensors(string, *Params) ([]llm.Tensor, error) kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
GetParams(string) (*Params, error) }
GetModelArch(string, string, *Params) (ModelArch, error)
}
type ModelData struct { return kv
Path string
Name string
Params *Params
Vocab *Vocab
Tensors []llm.Tensor
Format ModelFormat
} }
func GetModelFormat(dirname string) (ModelFormat, error) { func (Parameters) specialTokenTypes() []string {
files, err := filepath.Glob(filepath.Join(dirname, "*")) return []string{
if err != nil { "bos", "eos", "unk", "sep", "pad", "cls", "mask",
return nil, err
}
for _, fn := range files {
if strings.HasSuffix(fn, ".safetensors") {
return &SafetensorFormat{}, nil
} else if strings.HasSuffix(fn, ".bin") || strings.HasSuffix(fn, ".pth") {
slog.Debug("model is torch")
return &TorchFormat{}, nil
}
} }
}
return nil, fmt.Errorf("couldn't determine model format") func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
} }
// Details on gguf's tokenizer can be found at: type Converter interface {
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer // KV maps parameters to LLM key-values
type Vocab struct { KV(*Tokenizer) llm.KV
Tokens []string // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Scores []float32 Tensors([]Tensor) []llm.Tensor
Types []int32
Merges []string // tensorName returns the LLM tensor name for a specific input name
tensorName(string) string
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
} }
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) { // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model"))) // and files it finds in the input path.
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model")) // Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func Convert(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil { if err != nil {
return nil, err return err
} }
// To regenerate sentencepiece from the protobufs use: var p Parameters
// protoc -I=./ --go_out=./ sentencepiece_model.proto if err := json.Unmarshal(bts, &p); err != nil {
modelProto := &sentencepiece.ModelProto{} return err
if err := proto.Unmarshal(in, modelProto); err != nil {
return nil, err
} }
v := &Vocab{ if len(p.Architectures) < 1 {
Tokens: make([]string, 0), return errors.New("unknown architecture")
Scores: make([]float32, 0),
Types: make([]int32, 0),
} }
pieces := modelProto.GetPieces() var conv Converter
for _, p := range pieces { switch p.Architectures[0] {
v.Tokens = append(v.Tokens, p.GetPiece()) case "LlamaForCausalLM", "MistralForCausalLM":
v.Scores = append(v.Scores, p.GetScore()) conv = &llama{}
t := p.GetType() case "MixtralForCausalLM":
switch t { conv = &mixtral{}
case sentencepiece.ModelProto_SentencePiece_UNKNOWN: case "GemmaForCausalLM":
case sentencepiece.ModelProto_SentencePiece_CONTROL: conv = &gemma{}
case sentencepiece.ModelProto_SentencePiece_UNUSED: default:
case sentencepiece.ModelProto_SentencePiece_BYTE: return errors.New("unsupported architecture")
default:
t = sentencepiece.ModelProto_SentencePiece_NORMAL
}
v.Types = append(v.Types, int32(t))
}
slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens)))
// add any additional tokens
addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json"))
if os.IsNotExist(err) {
return v, nil
} else if err != nil {
return nil, err
}
slog.Info("reading user defined tokens")
var extraTokenData map[string]int
if err := json.Unmarshal(addIn, &extraTokenData); err != nil {
return nil, err
} }
type token struct { if err := json.Unmarshal(bts, conv); err != nil {
key string return err
pos int
} }
extraTokens := make([]token, 0) t, err := parseTokenizer(fsys, conv.specialTokenTypes())
for k, id := range extraTokenData { if err != nil {
extraTokens = append(extraTokens, token{k, id}) return err
} }
slices.SortFunc(extraTokens, func(a, b token) int { if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
return cmp.Compare(a.pos, b.pos) slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
}) for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
numToks := len(v.Tokens) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
for cnt, t := range extraTokens {
// the token id should match the specific index for the total number of tokens
if t.pos != cnt+numToks {
return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key)
} }
v.Tokens = append(v.Tokens, t.key) } else {
v.Scores = append(v.Scores, -1000.0) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
v.Types = append(v.Types, tokenTypeUserDefined)
} }
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
ts, err := parseTensors(fsys)
if params.VocabSize > len(v.Tokens) { if err != nil {
missingTokens := params.VocabSize - len(v.Tokens) return err
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := 0; cnt < missingTokens; cnt++ {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
v.Scores = append(v.Scores, -1)
v.Types = append(v.Types, tokenTypeUserDefined)
}
} }
return v, nil return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
} }
package convert
import (
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type gemma struct {
Parameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
}
var _ Converter = (*gemma)(nil)
func (p *gemma) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t)
kv["general.architecture"] = "gemma"
kv["general.name"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings
kv["gemma.embedding_length"] = p.HiddenSize
kv["gemma.block_count"] = p.HiddenLayers
kv["gemma.feed_forward_length"] = p.IntermediateSize
kv["gemma.attention.head_count"] = p.NumAttentionHeads
kv["gemma.attention.head_count_kv"] = p.NumKeyValueHeads
kv["gemma.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma.attention.key_length"] = p.HeadDim
kv["gemma.attention.value_length"] = p.HeadDim
kv["tokenizer.ggml.eot_token_id"] = uint32(107)
kv["tokenizer.ggml.middle_token_id"] = uint32(68)
kv["tokenizer.ggml.prefix_token_id"] = uint32(67)
kv["tokenizer.ggml.suffix_token_id"] = uint32(69)
return kv
}
func (p *gemma) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasSuffix(name, "_norm.weight") {
t.SetRepacker(p.addOne)
}
out = append(out, llm.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *gemma) tensorName(n string) string {
return strings.NewReplacer(
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
"block_sparse_moe.gate", "ffn_inp",
).Replace(n)
}
func (*gemma) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
n, err := n.Add(ones)
if err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 0)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type llama struct {
Parameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayer uint32 `json:"n_layer"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NCtx uint32 `json:"n_ctx"`
HiddenSize uint32 `json:"hidden_size"`
NEmbd uint32 `json:"n_embd"`
IntermediateSize uint32 `json:"intermediate_size"`
NInner uint32 `json:"n_inner"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NHead uint32 `json:"n_head"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
Factor float32 `json:"factor"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"`
}
var _ Converter = (*llama)(nil)
func (p *llama) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t)
kv["general.architecture"] = "llama"
kv["general.name"] = "llama"
kv["llama.vocab_size"] = p.VocabSize
kv["llama.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
kv["llama.context_length"] = contextLength
}
if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
kv["llama.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
}
if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
kv["llama.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
}
if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
kv["llama.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
}
if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta
}
if p.RopeScaling.Type == "linear" {
kv["llama.rope.scaling.type"] = p.RopeScaling.Type
kv["llama.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.NumKeyValueHeads > 0 {
kv["llama.attention.head_count_kv"] = p.NumKeyValueHeads
}
if p.RMSNormEPS > 0 {
kv["llama.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
kv["llama.attention.layer_norm_epsilon"] = layerNormEpsilon
}
if p.HeadDim > 0 {
kv["llama.attention.key_length"] = p.HeadDim
kv["llama.attention.value_length"] = p.HeadDim
}
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
return kv
}
func (p *llama) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasSuffix(name, "attn_q.weight") ||
strings.HasSuffix(name, "attn_k.weight") {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *llama) tensorName(n string) string {
return strings.NewReplacer(
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
// mixtral
"block_sparse_moe.gate", "ffn_gate_inp",
).Replace(n)
}
func (p *llama) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, "q_proj.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "k_proj.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
package convert
import (
"fmt"
"io"
"slices"
"strings"
"github.com/ollama/ollama/llm"
)
type mixtral struct {
llama
NumLocalExperts uint32 `json:"num_local_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
var _ Converter = (*mixtral)(nil)
func (p *mixtral) KV(t *Tokenizer) llm.KV {
kv := p.llama.KV(t)
if p.NumLocalExperts > 0 {
kv["llama.expert_count"] = p.NumLocalExperts
}
if p.NumExpertsPerToken > 0 {
kv["llama.expert_used_count"] = p.NumExpertsPerToken
}
return kv
}
func (p *mixtral) Tensors(ts []Tensor) []llm.Tensor {
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
"w2", "ffn_down_exps",
"w3", "ffn_up_exps",
}
for i := range p.NumLocalExperts {
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
}
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
namer := strings.NewReplacer(oldnew...)
experts := make(map[string]experts)
// merge experts into a single tensor while removing them from ts
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
return false
}
name := namer.Replace(t.Name())
experts[name] = append(experts[name], t)
return true
})
var out []llm.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, llm.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
WriterTo: e,
})
}
return append(out, p.llama.Tensors(ts)...)
}
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
for _, t := range e {
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
// this accomplishes the same thing by writing each expert tensor in sequence
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}
//go:build slow
package convert package convert
import ( import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"testing" "testing"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
func convertFull(t *testing.T, p string) (llm.KV, llm.Tensors) { func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
t.Helper() t.Helper()
mf, err := GetModelFormat(p) f, err := os.CreateTemp(t.TempDir(), "f16")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer f.Close()
params, err := mf.GetParams(p) if err := Convert(fsys, f); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
arch, err := mf.GetModelArch("", p, params) r, err := os.Open(f.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(func() { r.Close() })
if err := arch.LoadVocab(); err != nil { m, _, err := llm.DecodeGGML(r, math.MaxInt)
t.Fatal(err)
}
if err := arch.GetTensors(); err != nil {
t.Fatal(err)
}
f, err := os.CreateTemp(t.TempDir(), "f16")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer f.Close()
if err := arch.WriteGGUF(f); err != nil { if _, err := r.Seek(0, io.SeekStart); err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, err := os.Open(f.Name()) return r, m.KV(), m.Tensors()
if err != nil { }
t.Fatal(err)
}
defer r.Close()
m, _, err := llm.DecodeGGML(r)
if err != nil {
t.Fatal(err)
}
return m.KV(), m.Tensors() func TestMain(m *testing.M) {
var level slog.Level
flag.TextVar(&level, "level", slog.LevelInfo, "log level")
flag.Parse()
slog.SetLogLoggerLevel(level)
os.Exit(m.Run())
} }
func TestConvertFull(t *testing.T) { func TestConvertFull(t *testing.T) {
cases := []struct { cases := []string{
path string "Meta-Llama-3-8B-Instruct",
arch string "Mistral-7B-Instruct-v0.2",
tensors int "Mixtral-8x7B-Instruct-v0.1",
layers int "gemma-2b-it",
}{
{"Meta-Llama-3-8B-Instruct", "llama", 291, 35},
{"Mistral-7B-Instruct-v0.2", "llama", 291, 35},
{"Mixtral-8x7B-Instruct-v0.1", "llama", 291, 35},
{"gemma-2b-it", "gemma", 164, 20},
} }
for _, tt := range cases { for i := range cases {
t.Run(tt.path, func(t *testing.T) { tt := cases[i]
p := filepath.Join("testdata", tt.path) t.Run(tt, func(t *testing.T) {
if _, err := os.Stat(p); err != nil { t.Parallel()
p := filepath.Join("testdata", tt)
if testing.Short() {
t.Skip("skipping in short mode")
} else if _, err := os.Stat(p); err != nil {
t.Skipf("%s not found", p) t.Skipf("%s not found", p)
} }
kv, tensors := convertFull(t, p) f, kv, tensors := convertFull(t, os.DirFS(p))
actual := make(map[string]string)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
bts, err := json.Marshal(s)
if err != nil {
t.Fatal(err)
}
actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
}
}
for _, tensor := range tensors.Items {
sha256sum := sha256.New()
sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
if _, err := io.Copy(sha256sum, sr); err != nil {
t.Fatal(err)
}
if kv.Architecture() != tt.arch { actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
t.Fatalf("expected llama, got %s", kv.Architecture())
} }
if kv.FileType().String() != "F16" { expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
t.Fatalf("expected F16, got %s", kv.FileType()) if err != nil {
t.Fatal(err)
} }
if len(tensors) != tt.tensors { var expect map[string]string
t.Fatalf("expected %d tensors, got %d", tt.tensors, len(tensors)) if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
t.Fatal(err)
} }
layers := tensors.Layers() keys := maps.Keys(expect)
if len(layers) != tt.layers { slices.Sort(keys)
t.Fatalf("expected %d layers, got %d", tt.layers, len(layers)) for _, k := range keys {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != expect[k] {
t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
}
} }
}) })
} }
......
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment