"docs/source/zh/optimization/cache.md" did not exist on "4fcd0bc7ebb934a1559d0b516f09534ba22c8a0d"
Unverified Commit 2010cbc5 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

Merge pull request #3833 from ollama/mxyng/fix-from

fix: from blob
parents ade4b555 ac0801ec
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
...@@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error { ...@@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
bars := make(map[string]*progress.Bar)
modelfile, err := os.ReadFile(filename) modelfile, err := os.ReadFile(filename)
if err != nil { if err != nil {
return err return err
...@@ -95,32 +94,92 @@ func CreateHandler(cmd *cobra.Command, args []string) error { ...@@ -95,32 +94,92 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
// TODO make this work w/ adapters
if fi.IsDir() { if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf") // this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil { if err != nil {
return err return err
} }
defer os.RemoveAll(tf.Name()) defer os.RemoveAll(tempfile)
zf := zip.NewWriter(tf) path = tempfile
}
files := []string{} digest, err := createBlob(cmd, client, path)
if err != nil {
return err
}
name := c.Name
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
}
}
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
spinner.Stop()
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin")) tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil { if err != nil {
return err return "", err
} else if len(tfiles) == 0 { } else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors")) tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil { if err != nil {
return err return "", err
} }
} }
files := []string{}
files = append(files, tfiles...) files = append(files, tfiles...)
if len(files) == 0 { if len(files) == 0 {
return fmt.Errorf("no models were found in '%s'", path) return "", fmt.Errorf("no models were found in '%s'", path)
} }
// add the safetensor/torch config file + tokenizer // add the safetensor/torch config file + tokenizer
...@@ -142,90 +201,40 @@ func CreateHandler(cmd *cobra.Command, args []string) error { ...@@ -142,90 +201,40 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if os.IsNotExist(err) { if os.IsNotExist(err) {
continue continue
} else if err != nil { } else if err != nil {
return err return "", err
} }
} else { } else {
continue continue
} }
} else if err != nil { } else if err != nil {
return err return "", err
} }
fi, err := f.Stat() fi, err := f.Stat()
if err != nil { if err != nil {
return err return "", err
} }
h, err := zip.FileInfoHeader(fi) h, err := zip.FileInfoHeader(fi)
if err != nil { if err != nil {
return err return "", err
} }
h.Name = filepath.Base(fn) h.Name = filepath.Base(fn)
h.Method = zip.Store h.Method = zip.Store
w, err := zf.CreateHeader(h) w, err := zipfile.CreateHeader(h)
if err != nil { if err != nil {
return err return "", err
} }
_, err = io.Copy(w, f) _, err = io.Copy(w, f)
if err != nil { if err != nil {
return err return "", err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
}
digest, err := createBlob(cmd, client, path)
if err != nil {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
}
}
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
spinner.Stop()
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
} }
quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
} }
return nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment