Unverified commit 1e7f62cb authored by Blake Mizerany, committed by GitHub

cmd: add retry/backoff (#10069)

This commit adds retry/backoff to the registry client for pull requests.

Also, revert progress indication to match original client's until we can
"get it right."

Also, make WithTrace wrap existing traces instead of clobbering them.
This allows clients to compose traces.
parent ccb7eb81
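For reference, the retry shape added on the server side (see the handlePull diff below) reduces to the following sketch. `backoff.Loop`, `ollama.Error.Temporary`, and `Registry.Pull` are used exactly as they appear in this diff; the wrapper function and its package name are illustrative only, and the internal `backoff` import is reachable only from inside this repo.

```go
package pullretry

import (
	"context"
	"errors"
	"time"

	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/internal/backoff"
)

// pullWithRetry keeps retrying a pull while the registry answers with a
// temporary (5xx) error, backing off up to 3s between attempts, mirroring
// the handlePull change below.
func pullWithRetry(ctx context.Context, r *ollama.Registry, name string) error {
	for _, err := range backoff.Loop(ctx, 3*time.Second) {
		if err != nil {
			return err // context cancelled or deadline exceeded
		}
		err := r.Pull(ctx, name)
		var oe *ollama.Error
		if errors.As(err, &oe) && oe.Temporary() {
			continue // registry hiccup; back off and retry
		}
		return err // success, or a non-retryable error
	}
	return nil
}
```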
@@ -808,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
+			if resp.Completed == 0 {
+				// This is the initial status update for the
+				// layer, which the server sends before
+				// beginning the download, for clients to
+				// compute total size and prepare for
+				// downloads, if needed.
+				//
+				// Skipping this here to avoid showing a 0%
+				// progress bar, which *should* clue the user
+				// into the fact that many things are being
+				// downloaded and that the current active
+				// download is not that last. However, in rare
+				// cases it seems to be triggering to some, and
+				// it isn't worth explaining, so just ignore
+				// and regress to the old UI that keeps giving
+				// you the "But wait, there is more!" after
+				// each "100% done" bar, which is "better."
+				return nil
+			}
 			if spinner != nil {
 				spinner.Stop()
 			}
 			bar, ok := bars[resp.Digest]
 			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
+				name = strings.TrimSpace(name)
+				if isDigest {
+					name = name[:min(12, len(name))]
+				}
+				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
 				bars[resp.Digest] = bar
 				p.Add(resp.Digest, bar)
 			}
@@ -834,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	request := api.PullRequest{Name: args[0], Insecure: insecure}
-	if err := client.Pull(cmd.Context(), &request, fn); err != nil {
-		return err
-	}
-	return nil
+	return client.Pull(cmd.Context(), &request, fn)
 }
 
 type generateContextKey string
...
@@ -107,15 +107,20 @@ func DefaultCache() (*blob.DiskCache, error) {
 //
 // In both cases, the code field is optional and may be empty.
 type Error struct {
-	Status  int    `json:"-"` // TODO(bmizerany): remove this
+	status  int    `json:"-"` // TODO(bmizerany): remove this
 	Code    string `json:"code"`
 	Message string `json:"message"`
 }
 
+// Temporary reports if the error is temporary (e.g. 5xx status code).
+func (e *Error) Temporary() bool {
+	return e.status >= 500
+}
+
 func (e *Error) Error() string {
 	var b strings.Builder
 	b.WriteString("registry responded with status ")
-	b.WriteString(strconv.Itoa(e.Status))
+	b.WriteString(strconv.Itoa(e.status))
 	if e.Code != "" {
 		b.WriteString(": code ")
 		b.WriteString(e.Code)
@@ -129,7 +134,7 @@ func (e *Error) Error() string {
 
 func (e *Error) LogValue() slog.Value {
 	return slog.GroupValue(
-		slog.Int("status", e.Status),
+		slog.Int("status", e.status),
 		slog.String("code", e.Code),
 		slog.String("message", e.Message),
 	)
@@ -428,12 +433,12 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 
 type trackingReader struct {
 	l      *Layer
 	r      io.Reader
-	update func(l *Layer, n int64, err error)
+	update func(n int64)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.update(r.l, int64(n), nil)
+	r.update(int64(n))
 	return
 }
@@ -478,111 +483,120 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		expected += l.Size
 	}
 
-	var received atomic.Int64
+	var completed atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
+		var received atomic.Int64
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
 			received.Add(l.Size)
+			completed.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
 
-		var wg sync.WaitGroup
-		chunked, err := c.Chunked(l.Digest, l.Size)
-		if err != nil {
-			t.update(l, 0, err)
-			continue
-		}
-
-		for cs, err := range r.chunksums(ctx, name, l) {
+		func() {
+			var wg sync.WaitGroup
+			chunked, err := c.Chunked(l.Digest, l.Size)
 			if err != nil {
-				// Chunksum stream interrupted. Note in trace
-				// log and let in-flight downloads complete.
-				// This will naturally trigger ErrIncomplete
-				// since received < expected bytes.
-				t.update(l, 0, err)
-				break
+				t.update(l, received.Load(), err)
+				return
 			}
+			defer func() {
+				// Close the chunked writer when all chunks are
+				// downloaded.
+				//
+				// This is done as a background task in the
+				// group to allow the next layer to start while
+				// we wait for the final chunk in this layer to
+				// complete. It also ensures this is done
+				// before we exit Pull.
+				g.Go(func() error {
+					wg.Wait()
+					chunked.Close()
+					return nil
+				})
+			}()
 
-			cacheKey := fmt.Sprintf(
-				"v1 pull chunksum %s %s %d-%d",
-				l.Digest,
-				cs.Digest,
-				cs.Chunk.Start,
-				cs.Chunk.End,
-			)
-			cacheKeyDigest := blob.DigestFromBytes(cacheKey)
-			_, err := c.Get(cacheKeyDigest)
-			if err == nil {
-				received.Add(cs.Chunk.Size())
-				t.update(l, cs.Chunk.Size(), ErrCached)
-				continue
-			}
+			for cs, err := range r.chunksums(ctx, name, l) {
+				if err != nil {
+					// Chunksum stream interrupted. Note in trace
+					// log and let in-flight downloads complete.
+					// This will naturally trigger ErrIncomplete
+					// since received < expected bytes.
+					t.update(l, received.Load(), err)
+					break
+				}
 
-			wg.Add(1)
-			g.Go(func() (err error) {
-				defer func() {
-					if err == nil {
-						// Ignore cache key write errors for now. We've already
-						// reported to trace that the chunk is complete.
-						//
-						// Ideally, we should only report completion to trace
-						// after successful cache commit. This current approach
-						// works but could trigger unnecessary redownloads if
-						// the checkpoint key is missing on next pull.
-						//
-						// Not incorrect, just suboptimal - fix this in a
-						// future update.
-						_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
-
-						received.Add(cs.Chunk.Size())
-					} else {
-						t.update(l, 0, err)
-					}
-					wg.Done()
-				}()
-
-				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-				if err != nil {
-					return err
+				cacheKey := fmt.Sprintf(
+					"v1 pull chunksum %s %s %d-%d",
+					l.Digest,
+					cs.Digest,
+					cs.Chunk.Start,
+					cs.Chunk.End,
+				)
+				cacheKeyDigest := blob.DigestFromBytes(cacheKey)
+				_, err := c.Get(cacheKeyDigest)
+				if err == nil {
+					recv := received.Add(cs.Chunk.Size())
+					completed.Add(cs.Chunk.Size())
+					t.update(l, recv, ErrCached)
+					continue
 				}
-				req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-				res, err := sendRequest(r.client(), req)
-				if err != nil {
-					return err
-				}
-				defer res.Body.Close()
 
-				body := &trackingReader{l: l, r: res.Body, update: t.update}
-				return chunked.Put(cs.Chunk, cs.Digest, body)
-			})
-		}
+				wg.Add(1)
+				g.Go(func() (err error) {
+					defer func() {
+						if err == nil {
+							// Ignore cache key write errors for now. We've already
+							// reported to trace that the chunk is complete.
+							//
+							// Ideally, we should only report completion to trace
+							// after successful cache commit. This current approach
+							// works but could trigger unnecessary redownloads if
+							// the checkpoint key is missing on next pull.
+							//
+							// Not incorrect, just suboptimal - fix this in a
+							// future update.
+							_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+						} else {
+							t.update(l, received.Load(), err)
+						}
+						wg.Done()
+					}()
 
-		// Close writer immediately after downloads finish, not at Pull
-		// exit. Using defer would keep file descriptors open until all
-		// layers complete, potentially exhausting system limits with
-		// many layers.
-		//
-		// The WaitGroup tracks when all chunks finish downloading,
-		// allowing precise writer closure in a background goroutine.
-		// Each layer briefly uses one extra goroutine while at most
-		// maxStreams()-1 chunks download in parallel.
-		//
-		// This caps file descriptors at maxStreams() instead of
-		// growing with layer count.
-		g.Go(func() error {
-			wg.Wait()
-			chunked.Close()
-			return nil
-		})
+					req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+					if err != nil {
+						return err
+					}
+					req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+					res, err := sendRequest(r.client(), req)
+					if err != nil {
+						return err
+					}
+					defer res.Body.Close()
+
+					tr := &trackingReader{
+						l: l,
+						r: res.Body,
+						update: func(n int64) {
+							completed.Add(n)
+							recv := received.Add(n)
+							t.update(l, recv, nil)
+						},
+					}
+					return chunked.Put(cs.Chunk, cs.Digest, tr)
+				})
+			}
+		}()
 	}
 
 	if err := g.Wait(); err != nil {
 		return err
 	}
 
-	if received.Load() != expected {
-		return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
+	if recv := completed.Load(); recv != expected {
+		return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, recv, expected)
 	}
 
 	md := blob.DigestFromBytes(m.Data)
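As an aside, the per-layer writer lifecycle introduced above reduces to the following sketch. Only the WaitGroup-plus-errgroup shape comes from the diff; `chunk`, `fetch`, and `pullLayer` are illustrative stand-ins, not repo APIs. Each layer's writer is closed by a background goroutine in the same group once that layer's own chunks finish, so the next layer can start immediately and open writers do not pile up until Pull returns.

```go
package pullsketch

import (
	"io"
	"sync"

	"golang.org/x/sync/errgroup"
)

// chunk and fetch are stand-ins for the repo's chunk metadata and the
// Range-GET download into the chunked writer.
type chunk struct{ start, end int64 }

func fetch(c chunk, w io.Writer) error { return nil }

// pullLayer downloads one layer's chunks on the shared errgroup and
// schedules a background close of the layer's writer once this layer's
// WaitGroup drains, mirroring the deferred g.Go in Registry.Pull above.
func pullLayer(g *errgroup.Group, chunks []chunk, w io.WriteCloser) {
	var wg sync.WaitGroup
	defer g.Go(func() error {
		wg.Wait() // waits only for this layer's chunks
		return w.Close()
	})
	for _, c := range chunks {
		wg.Add(1)
		g.Go(func() error {
			defer wg.Done()
			return fetch(c, w)
		})
	}
}
```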
@@ -973,7 +987,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
 			return nil, ErrModelNotFound
 		}
 
-		re.Status = res.StatusCode
+		re.status = res.StatusCode
 		return nil, &re
 	}
 	return res, nil
...
@@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
 func checkErrCode(t *testing.T, err error, status int, code string) {
 	t.Helper()
 	var e *Error
-	if !errors.As(err, &e) || e.Status != status || e.Code != code {
+	if !errors.As(err, &e) || e.status != status || e.Code != code {
 		t.Errorf("err = %v; want %v %v", err, status, code)
 	}
 }
@@ -860,8 +860,8 @@ func TestPullChunksumStreaming(t *testing.T) {
 	// now send the second chunksum and ensure it kicks off work immediately
 	fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
 
-	if g := <-update; g != 1 {
-		t.Fatalf("got %d, want 1", g)
+	if g := <-update; g != 3 {
+		t.Fatalf("got %d, want 3", g)
 	}
 	csw.Close()
 	testutil.Check(t, <-errc)
@@ -944,10 +944,10 @@ func TestPullChunksumsCached(t *testing.T) {
 	_, err = c.Cache.Resolve("o.com/library/abc:latest")
 	check(err)
 
-	if g := written.Load(); g != 3 {
+	if g := written.Load(); g != 5 {
 		t.Fatalf("wrote %d bytes, want 3", g)
 	}
 	if g := cached.Load(); g != 2 { // "ab" should have been cached
-		t.Fatalf("cached %d bytes, want 3", g)
+		t.Fatalf("cached %d bytes, want 5", g)
 	}
 }
@@ -34,10 +34,27 @@ func (t *Trace) update(l *Layer, n int64, err error) {
 
 type traceKey struct{}
 
-// WithTrace returns a context derived from ctx that uses t to report trace
-// events.
+// WithTrace adds a trace to the context for transfer progress reporting.
 func WithTrace(ctx context.Context, t *Trace) context.Context {
-	return context.WithValue(ctx, traceKey{}, t)
+	old := traceFromContext(ctx)
+	if old == t {
+		// No change, return the original context. This also prevents
+		// infinite recursion below, if the caller passes the same
+		// Trace.
+		return ctx
+	}
+
+	// Create a new Trace that wraps the old one, if any. If we used the
+	// same pointer t, we end up with a recursive structure.
+	composed := &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if old != nil {
+				old.update(l, n, err)
+			}
+			t.update(l, n, err)
+		},
+	}
+	return context.WithValue(ctx, traceKey{}, composed)
 }
 
 var emptyTrace = &Trace{}
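Because WithTrace now wraps whatever trace is already in the context, callers can layer traces. A minimal usage sketch follows; `withLogging` is illustrative, while the `Trace`/`WithTrace` API is exactly as shown above, and the pre-existing trace fires before the newly added one.

```go
package tracesketch

import (
	"context"
	"log/slog"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

// withLogging layers a debug-logging trace on top of any trace already in
// ctx; since WithTrace wraps instead of clobbering, both the existing
// trace and this one receive every update (existing first).
func withLogging(ctx context.Context) context.Context {
	return ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			slog.Debug("pull progress", "digest", l.Digest, "completed", n, "err", err)
		},
	})
}
```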
...
@@ -9,13 +9,14 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"maps"
 	"net/http"
+	"slices"
 	"sync"
 	"time"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
+	"github.com/ollama/ollama/server/internal/internal/backoff"
 )
 
 // Local implements an http.Handler for handling local Ollama API model
@@ -265,68 +266,81 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 			}
 			return err
 		}
-		return enc.Encode(progressUpdateJSON{Status: "success"})
+		enc.Encode(progressUpdateJSON{Status: "success"})
+		return nil
 	}
 
-	maybeFlush := func() {
-		fl, _ := w.(http.Flusher)
-		if fl != nil {
-			fl.Flush()
-		}
-	}
-	defer maybeFlush()
-
 	var mu sync.Mutex
-	progress := make(map[*ollama.Layer]int64)
-
-	progressCopy := make(map[*ollama.Layer]int64, len(progress))
+	var progress []progressUpdateJSON
+
 	flushProgress := func() {
-		defer maybeFlush()
-
-		// TODO(bmizerany): Flushing every layer in one update doesn't
-		// scale well. We could flush only the modified layers or track
-		// the full download. Needs further consideration, though it's
-		// fine for now.
 		mu.Lock()
-		maps.Copy(progressCopy, progress)
+		progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
 		mu.Unlock()
-		for l, n := range progressCopy {
-			enc.Encode(progressUpdateJSON{
-				Digest:    l.Digest,
-				Total:     l.Size,
-				Completed: n,
-			})
+		for _, p := range progress {
+			enc.Encode(p)
+		}
+		fl, _ := w.(http.Flusher)
+		if fl != nil {
+			fl.Flush()
 		}
 	}
 	defer flushProgress()
 
-	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
+	t := time.NewTicker(1<<63 - 1) // "unstarted" timer
 	start := sync.OnceFunc(func() {
 		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
-			if n > 0 {
-				// Block flushing progress updates until every
-				// layer is accounted for. Clients depend on a
-				// complete model size to calculate progress
-				// correctly; if they use an incomplete total,
-				// progress indicators would erratically jump
-				// as new layers are registered.
-				start()
+			if err != nil && !errors.Is(err, ollama.ErrCached) {
+				s.Logger.Error("pulling", "model", p.model(), "error", err)
+				return
 			}
 
-			mu.Lock()
-			progress[l] += n
-			mu.Unlock()
+			func() {
+				mu.Lock()
+				defer mu.Unlock()
+				for i, p := range progress {
+					if p.Digest == l.Digest {
+						progress[i].Completed = n
+						return
+					}
+				}
+				progress = append(progress, progressUpdateJSON{
+					Digest: l.Digest,
+					Total:  l.Size,
+				})
+			}()
+
+			// Block flushing progress updates until every
+			// layer is accounted for. Clients depend on a
+			// complete model size to calculate progress
+			// correctly; if they use an incomplete total,
+			// progress indicators would erratically jump
+			// as new layers are registered.
+			start()
 		},
 	})
 
 	done := make(chan error, 1)
-	go func() {
-		done <- s.Client.Pull(ctx, p.model())
+	go func() (err error) {
+		defer func() { done <- err }()
+		for _, err := range backoff.Loop(ctx, 3*time.Second) {
+			if err != nil {
+				return err
+			}
+			err := s.Client.Pull(ctx, p.model())
+			var oe *ollama.Error
+			if errors.As(err, &oe) && oe.Temporary() {
+				continue // retry
+			}
+			return err
+		}
+		return nil
 	}()
 
+	enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
+
 	for {
 		select {
 		case <-t.C:
@@ -341,7 +355,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 					status = fmt.Sprintf("error: %v", err)
 				}
 				enc.Encode(progressUpdateJSON{Status: status})
+				return nil
 			}
+
+			// Emulate old client pull progress (for now):
+			enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
+			enc.Encode(progressUpdateJSON{Status: "writing manifest"})
+			enc.Encode(progressUpdateJSON{Status: "success"})
 			return nil
 		}
 	}
...
@@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
 
 func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
 	t.Helper()
-	req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
+	ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
+		Update: func(l *ollama.Layer, n int64, err error) {
+			t.Logf("update: %s %d %v", l.Digest, n, err)
+		},
+	})
+	req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
 	return s.sendRequest(t, req)
 }
@@ -184,36 +189,34 @@ func TestServerPull(t *testing.T) {
 	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
 		t.Helper()
 
 		if got.Code != 200 {
 			t.Errorf("Code = %d; want 200", got.Code)
 		}
 
 		gotlines := got.Body.String()
+		if strings.TrimSpace(gotlines) == "" {
+			gotlines = "<empty>"
+		}
 		t.Logf("got:\n%s", gotlines)
 		for want := range strings.Lines(wantlines) {
 			want = strings.TrimSpace(want)
 			want, unwanted := strings.CutPrefix(want, "!")
 			want = strings.TrimSpace(want)
 			if !unwanted && !strings.Contains(gotlines, want) {
-				t.Errorf("! missing %q in body", want)
+				t.Errorf("\t! missing %q in body", want)
 			}
 			if unwanted && strings.Contains(gotlines, want) {
-				t.Errorf("! unexpected %q in body", want)
+				t.Errorf("\t! unexpected %q in body", want)
 			}
 		}
 	}
 
-	got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
+	got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
 	checkResponse(got, `
-		{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
-	`)
-
-	got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
-	checkResponse(got, `
-		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
-		{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
+		{"status":"pulling manifest"}
 		{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
-		{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+		{"status":"verifying sha256 digest"}
+		{"status":"writing manifest"}
+		{"status":"success"}
 	`)
 
 	got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
...