Unverified commit 6ebab38b, authored by Dane Madsen, committed by GitHub
Browse files

Merge branch 'jmorganca:main' into main

parents 5d8e864d a3fcecf9
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"math"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
...@@ -53,8 +54,8 @@ type blobDownloadPart struct { ...@@ -53,8 +54,8 @@ type blobDownloadPart struct {
const ( const (
numDownloadParts = 64 numDownloadParts = 64
minDownloadPartSize int64 = 32 * 1000 * 1000 minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 256 * 1000 * 1000 maxDownloadPartSize int64 = 1000 * format.MegaByte
) )
func (p *blobDownloadPart) Name() string { func (p *blobDownloadPart) Name() string {
...@@ -147,7 +148,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis ...@@ -147,7 +148,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
continue continue
} }
i := i
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
...@@ -158,12 +158,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis ...@@ -158,12 +158,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
return err return err
case err != nil: case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue continue
default: default:
if try > 0 {
log.Printf("%s part %d completed after %d retries", b.Digest[7:19], i, try)
}
return nil return nil
} }
} }
...@@ -285,7 +284,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ...@@ -285,7 +284,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
} }
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", b.Digest), Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest, Digest: b.Digest,
Total: b.Total, Total: b.Total,
Completed: b.Completed.Load(), Completed: b.Completed.Load(),
...@@ -304,7 +303,7 @@ type downloadOpts struct { ...@@ -304,7 +303,7 @@ type downloadOpts struct {
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
const maxRetries = 3 const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded") var errMaxRetriesExceeded = errors.New("max retries exceeded")
...@@ -322,7 +321,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { ...@@ -322,7 +321,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return err return err
default: default:
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", opts.digest), Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
Digest: opts.digest, Digest: opts.digest,
Total: fi.Size(), Total: fi.Size(),
Completed: fi.Size(), Completed: fi.Size(),
......
...@@ -228,220 +228,181 @@ func GetModel(name string) (*Model, error) { ...@@ -228,220 +228,181 @@ func GetModel(name string) (*Model, error) {
return model, nil return model, nil
} }
func filenameWithPath(path, f string) (string, error) { func realpath(p string) string {
// if filePath starts with ~/, replace it with the user's home directory. abspath, err := filepath.Abs(p)
if strings.HasPrefix(f, fmt.Sprintf("~%s", string(os.PathSeparator))) { if err != nil {
parts := strings.Split(f, string(os.PathSeparator)) return p
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to open file: %v", err)
}
f = filepath.Join(home, filepath.Join(parts[1:]...))
}
// if filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(f) {
f = filepath.Join(filepath.Dir(path), f)
}
return f, nil
}
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
}
} }
mf, err := os.Open(path) home, err := os.UserHomeDir()
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) return abspath
return fmt.Errorf("failed to open file: %w", err)
} }
defer mf.Close()
fn(api.ProgressResponse{Status: "parsing modelfile"}) if p == "~" {
commands, err := parser.Parse(mf) return home
if err != nil { } else if strings.HasPrefix(p, "~/") {
return err return filepath.Join(home, p[2:])
} }
return abspath
}
func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
config := ConfigV2{ config := ConfigV2{
Architecture: "amd64",
OS: "linux", OS: "linux",
Architecture: "amd64",
} }
deleteMap := make(map[string]struct{})
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
var sourceParams map[string]any fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp)
if err != nil {
modelFile, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
return err return err
} }
if _, err := os.Stat(modelFile); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, _, err = GetManifest(mp)
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ggml, err := llm.DecodeGGML(file) c.Args = blobPath
if err != nil { }
bin, err := os.Open(realpath(c.Args))
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
config.ModelFormat = ggml.Name() manifest, _, err = GetManifest(modelpath)
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
// reset the file
file.Seek(0, io.SeekStart)
l, err := CreateLayer(file)
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return err
} }
l.MediaType = "application/vnd.ollama.image.model" case err != nil:
layers = append(layers, l) return err
} }
}
if mf != nil {
fn(api.ProgressResponse{Status: "reading model metadata"}) fn(api.ProgressResponse{Status: "reading model metadata"})
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceBlob, err := os.Open(sourceBlobPath) fromConfigFile, err := os.Open(fromConfigPath)
if err != nil { if err != nil {
return err return err
} }
defer sourceBlob.Close() defer fromConfigFile.Close()
var source ConfigV2 var fromConfig ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err return err
} }
// copy the model metadata config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = source.ModelFamily config.ModelFamily = fromConfig.ModelFamily
config.ModelType = source.ModelType config.ModelType = fromConfig.ModelType
config.ModelFormat = source.ModelFormat config.FileType = fromConfig.FileType
config.FileType = source.FileType
for _, l := range mf.Layers { for _, layer := range manifest.Layers {
if l.MediaType == "application/vnd.ollama.image.params" { deleteMap[layer.Digest] = struct{}{}
sourceParamsBlobPath, err := GetBlobsPath(l.Digest) if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceParamsBlob, err := os.Open(sourceParamsBlobPath) fromParamsFile, err := os.Open(fromParamsPath)
if err != nil { if err != nil {
return err return err
} }
defer sourceParamsBlob.Close() defer fromParamsFile.Close()
if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
return err return err
} }
} }
newLayer, err := GetLayerWithBufferFromLayer(l) layer, err := GetLayerWithBufferFromLayer(layer)
if err != nil { if err != nil {
return err return err
} }
newLayer.From = mp.GetShortTagname()
layers = append(layers, newLayer) layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
} }
deleteMap[manifest.Config.Digest] = struct{}{}
continue
} }
case "adapter": defer bin.Close()
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
fp, err := filenameWithPath(path, c.Args) fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil { if err != nil {
return err return err
} }
// create a model from this specified file config.ModelFormat = ggml.Name()
fn(api.ProgressResponse{Status: "creating model layer"}) config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
file, err := os.Open(fp) bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return err
} }
defer file.Close()
l, err := CreateLayer(file) layer.MediaType = mediatype
layers = append(layers, layer)
case "adapter":
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(c.Args))
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return err
}
defer bin.Close()
layer, err := CreateLayer(bin)
if err != nil {
return err
} }
l.MediaType = "application/vnd.ollama.image.adapter"
layers = append(layers, l)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
return err return err
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
case "template", "system", "prompt": case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove the layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) // remove duplicate layers
layers = removeLayerFromLayers(layers, mediaType) layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
...@@ -449,48 +410,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -449,48 +410,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
default: default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
// Create a single layer for the parameters
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
formattedParams, err := formatParams(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return err
} }
for k, v := range sourceParams { for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok { if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v formattedParams[k] = v
} }
} }
if config.ModelType == "65B" { if config.ModelType == "65B" {
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B" config.ModelType = "70B"
} }
} }
bts, err := json.Marshal(formattedParams) var b bytes.Buffer
if err != nil { if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
return err return err
} }
l, err := CreateLayer(bytes.NewReader(bts)) fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return err
} }
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l) layer.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, layer)
} }
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
...@@ -498,36 +458,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -498,36 +458,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
var manifestLayers []*Layer configLayer, err := createConfigLayer(config, digests)
for _, l := range layers {
manifestLayers = append(manifestLayers, &l.Layer)
delete(deleteMap, l.Layer.Digest)
}
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(config, digests)
if err != nil { if err != nil {
return err return err
} }
layers = append(layers, cfg)
delete(deleteMap, cfg.Layer.Digest) layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil { if err := SaveLayers(layers, fn, false); err != nil {
return err return err
} }
// Create the manifest var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers) if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err != nil {
return err return err
} }
if noprune == "" { if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"}) if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err return err
} }
} }
...@@ -739,7 +694,7 @@ func CopyModel(src, dest string) error { ...@@ -739,7 +694,7 @@ func CopyModel(src, dest string) error {
return nil return nil
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
fp, err := GetManifestPath() fp, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
...@@ -779,21 +734,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry ...@@ -779,21 +734,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
// only delete the files which are still in the deleteMap // only delete the files which are still in the deleteMap
for k, v := range deleteMap { for k := range deleteMap {
if v { fp, err := GetBlobsPath(k)
fp, err := GetBlobsPath(k) if err != nil {
if err != nil { log.Printf("couldn't get file path for '%s': %v", k, err)
log.Printf("couldn't get file path for '%s': %v", k, err) continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
log.Printf("couldn't remove file '%s': %v", fp, err)
continue continue
} }
if !dryRun { } else {
if err := os.Remove(fp); err != nil { log.Printf("wanted to remove: %s", fp)
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
} else {
log.Printf("wanted to remove: %s", fp)
}
} }
} }
...@@ -801,7 +754,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry ...@@ -801,7 +754,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
func PruneLayers() error { func PruneLayers() error {
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("") p, err := GetBlobsPath("")
if err != nil { if err != nil {
return err return err
...@@ -818,7 +771,9 @@ func PruneLayers() error { ...@@ -818,7 +771,9 @@ func PruneLayers() error {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
name = strings.ReplaceAll(name, "-", ":") name = strings.ReplaceAll(name, "-", ":")
} }
deleteMap[name] = true if strings.HasPrefix(name, "sha256:") {
deleteMap[name] = struct{}{}
}
} }
log.Printf("total blobs: %d", len(deleteMap)) log.Printf("total blobs: %d", len(deleteMap))
...@@ -873,11 +828,11 @@ func DeleteModel(name string) error { ...@@ -873,11 +828,11 @@ func DeleteModel(name string) error {
return err return err
} }
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true deleteMap[layer.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false) err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil { if err != nil {
...@@ -979,6 +934,9 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu ...@@ -979,6 +934,9 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
for _, layer := range layers { for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err) log.Printf("error uploading blob: %v", err)
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
}
return err return err
} }
} }
...@@ -1013,7 +971,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu ...@@ -1013,7 +971,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var noprune string var noprune string
// build deleteMap to prune unused layers // build deleteMap to prune unused layers
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp) manifest, _, err = GetManifest(mp)
...@@ -1023,9 +981,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu ...@@ -1023,9 +981,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if manifest != nil { if manifest != nil {
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = true deleteMap[l.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
} }
} }
...@@ -1165,44 +1123,52 @@ func GetSHA256Digest(r io.Reader) (string, int64) { ...@@ -1165,44 +1123,52 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
} }
var errUnauthorized = fmt.Errorf("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
for try := 0; try < maxRetries; try++ { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("request failed: %v", err)
}
return nil, err
}
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil { if err != nil {
log.Printf("couldn't start upload: %v", err)
return nil, err return nil, err
} }
regOpts.Token = token
switch { if body != nil {
case resp.StatusCode == http.StatusUnauthorized: _, err = body.Seek(0, io.SeekStart)
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
regOpts.Token = token resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if body != nil { if resp.StatusCode == http.StatusUnauthorized {
body.Seek(0, io.SeekStart) return nil, errUnauthorized
} }
continue
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, body) return resp, err
default: case resp.StatusCode == http.StatusNotFound:
return resp, nil return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
} }
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
} }
return nil, errMaxRetriesExceeded return resp, nil
} }
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
......
...@@ -2,6 +2,7 @@ package server ...@@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -26,6 +27,7 @@ import ( ...@@ -26,6 +27,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
...@@ -409,8 +411,31 @@ func CreateModelHandler(c *gin.Context) { ...@@ -409,8 +411,31 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" || req.Path == "" { if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}
var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
bin, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer bin.Close()
modelfile = bin
}
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
...@@ -424,7 +449,7 @@ func CreateModelHandler(c *gin.Context) { ...@@ -424,7 +449,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
...@@ -625,6 +650,60 @@ func CopyModelHandler(c *gin.Context) { ...@@ -625,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
} }
} }
// HeadBlobHandler answers whether the blob named by the ":digest" route
// parameter is present on disk.
//
// It responds 400 when the digest cannot be mapped to a blob path, 404 when
// stat-ing that path fails for any reason, and 200 otherwise.
func HeadBlobHandler(c *gin.Context) {
	digest := c.Param("digest")

	blobPath, err := GetBlobsPath(digest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	_, statErr := os.Stat(blobPath)
	if statErr != nil {
		message := fmt.Sprintf("blob %q not found", digest)
		c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": message})
		return
	}

	c.Status(http.StatusOK)
}
// CreateBlobHandler stores the request body as the blob named by the
// ":digest" route parameter.
//
// The body is streamed into a temporary file created in the same directory as
// the final blob path while its SHA-256 is computed on the fly. The handler
// responds:
//   - 500 if the blob path cannot be resolved or any filesystem step fails,
//   - 400 if the computed "sha256:<hex>" digest differs from the route digest,
//   - 201 once the temporary file has been renamed onto the blob path.
func CreateBlobHandler(c *gin.Context) {
	digest := c.Param("digest")

	targetPath, err := GetBlobsPath(digest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	hash := sha256.New()

	// Derive the temp-file pattern from the on-disk blob name rather than the
	// raw digest: an accepted digest has the form "sha256:<hex>", and ':' is
	// not a valid file-name character on Windows.
	temp, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath)+"-")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Close before removing. Two separate defers would run in LIFO order and
	// try to remove a still-open file, which fails on Windows and leaks the
	// temp file on every error path. After a successful rename both calls are
	// harmless no-ops, so their errors are intentionally ignored.
	defer func() {
		temp.Close()
		os.Remove(temp.Name())
	}()

	// Tee the body through the hash so it is read exactly once.
	if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != digest {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
		return
	}

	// Close explicitly so write-back errors surface before the blob is
	// published under its final name.
	if err := temp.Close(); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if err := os.Rename(temp.Name(), targetPath); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.Status(http.StatusCreated)
}
var defaultAllowOrigins = []string{ var defaultAllowOrigins = []string{
"localhost", "localhost",
"127.0.0.1", "127.0.0.1",
...@@ -684,6 +763,8 @@ func Serve(ln net.Listener, allowOrigins []string) error { ...@@ -684,6 +763,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/copy", CopyModelHandler) r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler) r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler) r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} { for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) { r.Handle(method, "/", func(c *gin.Context) {
...@@ -713,7 +794,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { ...@@ -713,7 +794,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
// check compatibility to log warnings // check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil { if _, err := llm.CheckVRAM(); err != nil {
log.Printf("Warning: GPU support may not be enabled, check you have installed GPU drivers: %v", err) log.Printf(err.Error())
} }
} }
......
...@@ -5,9 +5,9 @@ import ( ...@@ -5,9 +5,9 @@ import (
"crypto/md5" "crypto/md5"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"log" "log"
"math"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
...@@ -35,6 +35,8 @@ type blobUpload struct { ...@@ -35,6 +35,8 @@ type blobUpload struct {
context.CancelFunc context.CancelFunc
file *os.File
done bool done bool
err error err error
references atomic.Int32 references atomic.Int32
...@@ -42,8 +44,8 @@ type blobUpload struct { ...@@ -42,8 +44,8 @@ type blobUpload struct {
const ( const (
numUploadParts = 64 numUploadParts = 64
minUploadPartSize int64 = 95 * 1000 * 1000 minUploadPartSize int64 = 100 * format.MegaByte
maxUploadPartSize int64 = 1000 * 1000 * 1000 maxUploadPartSize int64 = 1000 * format.MegaByte
) )
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
...@@ -55,7 +57,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg ...@@ -55,7 +57,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
if b.From != "" { if b.From != "" {
values := requestURL.Query() values := requestURL.Query()
values.Add("mount", b.Digest) values.Add("mount", b.Digest)
values.Add("from", b.From) values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
requestURL.RawQuery = values.Encode() requestURL.RawQuery = values.Encode()
} }
...@@ -77,6 +79,14 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg ...@@ -77,6 +79,14 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
b.Total = fi.Size() b.Total = fi.Size()
// http.StatusCreated indicates a blob has been mounted
// ref: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
if resp.StatusCode == http.StatusCreated {
b.Completed.Store(b.Total)
b.done = true
return nil
}
var size = b.Total / numUploadParts var size = b.Total / numUploadParts
switch { switch {
case size < minUploadPartSize: case size < minUploadPartSize:
...@@ -120,12 +130,12 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { ...@@ -120,12 +130,12 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
return return
} }
f, err := os.Open(p) b.file, err = os.Open(p)
if err != nil { if err != nil {
b.err = err b.err = err
return return
} }
defer f.Close() defer b.file.Close()
g, inner := errgroup.WithContext(ctx) g, inner := errgroup.WithContext(ctx)
g.SetLimit(numUploadParts) g.SetLimit(numUploadParts)
...@@ -137,7 +147,6 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { ...@@ -137,7 +147,6 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
part.ReadSeeker = io.NewSectionReader(f, part.Offset, part.Size)
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
...@@ -145,7 +154,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { ...@@ -145,7 +154,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err) part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue continue
} }
...@@ -165,8 +177,16 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { ...@@ -165,8 +177,16 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL requestURL := <-b.nextURL
var sb strings.Builder var sb strings.Builder
// calculate md5 checksum and add it to the commit request
for _, part := range b.Parts { for _, part := range b.Parts {
sb.Write(part.Sum(nil)) hash := md5.New()
if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
b.err = err
return
}
sb.Write(hash.Sum(nil))
} }
md5sum := md5.Sum([]byte(sb.String())) md5sum := md5.Sum([]byte(sb.String()))
...@@ -180,29 +200,39 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { ...@@ -180,29 +200,39 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) for try := 0; try < maxRetries; try++ {
if err != nil { resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
b.err = err if err != nil {
b.err = err
if errors.Is(err, context.Canceled) {
return
}
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep)
continue
}
defer resp.Body.Close()
b.err = nil
b.done = true
return return
} }
defer resp.Body.Close()
b.done = true
} }
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
part.Reset()
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch { if method == http.MethodPatch {
headers.Set("X-Redirect-Uploads", "1")
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1)) headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
} }
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(part.ReadSeeker, io.MultiWriter(part, part.Hash)), opts) sr := io.NewSectionReader(b.file, part.Offset, part.Size)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
if err != nil { if err != nil {
return err return err
} }
...@@ -227,6 +257,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL ...@@ -227,6 +257,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return err return err
} }
// retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil) err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
...@@ -235,7 +266,10 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL ...@@ -235,7 +266,10 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err) part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue continue
} }
...@@ -260,7 +294,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL ...@@ -260,7 +294,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return err return err
} }
return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body) return fmt.Errorf("http status %s: %s", resp.Status, body)
} }
if method == http.MethodPatch { if method == http.MethodPatch {
...@@ -293,7 +327,7 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er ...@@ -293,7 +327,7 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
} }
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", b.Digest), Status: fmt.Sprintf("pushing %s", b.Digest[7:19]),
Digest: b.Digest, Digest: b.Digest,
Total: b.Total, Total: b.Total,
Completed: b.Completed.Load(), Completed: b.Completed.Load(),
...@@ -307,14 +341,10 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er ...@@ -307,14 +341,10 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
// blobUploadPart describes a single part of a blob upload: its part number
// and the byte range (Offset, Size) it covers.
// NOTE(review): this text is a side-by-side diff rendering. Lines that appear
// only once below (hash.Hash, io.ReadSeeker) are present on one side only and
// appear to be removed by this change — confirm against the real file.
type blobUploadPart struct { type blobUploadPart struct {
// N is the part number // N is the part number
N int N int
// Offset and Size select the section of the file read for this part
// (io.NewSectionReader) and are used to build the Content-Range header.
Offset int64 Offset int64
Size int64 Size int64
// Embedded hash; the one-sided Reset line reassigns it to md5.New().
hash.Hash
// written is the amount Reset subtracts from Completed before zeroing it.
// Presumably incremented by Write, which is not fully visible here — confirm.
written int64 written int64
// Embedded reader; the one-sided Reset line rewinds it with Seek(0, io.SeekStart).
io.ReadSeeker
// Embedded parent upload; Reset reaches the Completed counter through it.
*blobUpload *blobUpload
}
...@@ -326,10 +356,8 @@ func (p *blobUploadPart) Write(b []byte) (n int, err error) { ...@@ -326,10 +356,8 @@ func (p *blobUploadPart) Write(b []byte) (n int, err error) {
} }
// Reset discards this part's recorded progress: it subtracts written from the
// Completed counter and sets written back to zero.
// NOTE(review): the Seek and Hash statements appear on only one side of this
// side-by-side diff; they look like they were dropped together with the
// ReadSeeker/Hash fields — confirm against the real file.
func (p *blobUploadPart) Reset() { func (p *blobUploadPart) Reset() {
// rewind the embedded ReadSeeker to its start
p.Seek(0, io.SeekStart)
// take back what this part had added to Completed, then clear the local count
p.Completed.Add(-int64(p.written)) p.Completed.Add(-int64(p.written))
p.written = 0 p.written = 0
// replace the embedded hash with a fresh MD5 instance
p.Hash = md5.New()
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
...@@ -344,7 +372,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO ...@@ -344,7 +372,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO
default: default:
defer resp.Body.Close() defer resp.Body.Close()
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest), Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]),
Digest: layer.Digest, Digest: layer.Digest,
Total: layer.Size, Total: layer.Size,
Completed: layer.Size, Completed: layer.Size,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment