"examples/vscode:/vscode.git/clone" did not exist on "44638b93364efcf74ed9fe99eba45c73a0196799"
Unverified Commit 1e7f62cb authored by Blake Mizerany's avatar Blake Mizerany Committed by GitHub
Browse files

cmd: add retry/backoff (#10069)

This commit adds retry/backoff to the registry client for pull requests.

Also, revert progress indication to match original client's until we can
"get it right."

Also, make WithTrace wrap existing traces instead of clobbering them.
This allows clients to compose traces.
parent ccb7eb81
...@@ -808,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error { ...@@ -808,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
bar, ok := bars[resp.Digest] bar, ok := bars[resp.Digest]
if !ok { if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed) name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar bars[resp.Digest] = bar
p.Add(resp.Digest, bar) p.Add(resp.Digest, bar)
} }
...@@ -834,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { ...@@ -834,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
} }
request := api.PullRequest{Name: args[0], Insecure: insecure} request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(cmd.Context(), &request, fn); err != nil { return client.Pull(cmd.Context(), &request, fn)
return err
}
return nil
} }
type generateContextKey string type generateContextKey string
......
...@@ -107,15 +107,20 @@ func DefaultCache() (*blob.DiskCache, error) { ...@@ -107,15 +107,20 @@ func DefaultCache() (*blob.DiskCache, error) {
// //
// In both cases, the code field is optional and may be empty. // In both cases, the code field is optional and may be empty.
type Error struct { type Error struct {
Status int `json:"-"` // TODO(bmizerany): remove this status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"` Code string `json:"code"`
Message string `json:"message"` Message string `json:"message"`
} }
// Temporary reports if the error is temporary (e.g. 5xx status code).
func (e *Error) Temporary() bool {
return e.status >= 500
}
func (e *Error) Error() string { func (e *Error) Error() string {
var b strings.Builder var b strings.Builder
b.WriteString("registry responded with status ") b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status)) b.WriteString(strconv.Itoa(e.status))
if e.Code != "" { if e.Code != "" {
b.WriteString(": code ") b.WriteString(": code ")
b.WriteString(e.Code) b.WriteString(e.Code)
...@@ -129,7 +134,7 @@ func (e *Error) Error() string { ...@@ -129,7 +134,7 @@ func (e *Error) Error() string {
func (e *Error) LogValue() slog.Value { func (e *Error) LogValue() slog.Value {
return slog.GroupValue( return slog.GroupValue(
slog.Int("status", e.Status), slog.Int("status", e.status),
slog.String("code", e.Code), slog.String("code", e.Code),
slog.String("message", e.Message), slog.String("message", e.Message),
) )
...@@ -428,12 +433,12 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { ...@@ -428,12 +433,12 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
type trackingReader struct { type trackingReader struct {
l *Layer l *Layer
r io.Reader r io.Reader
update func(l *Layer, n int64, err error) update func(n int64)
} }
func (r *trackingReader) Read(p []byte) (n int, err error) { func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p) n, err = r.r.Read(p)
r.update(r.l, int64(n), nil) r.update(int64(n))
return return
} }
...@@ -478,23 +483,42 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -478,23 +483,42 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
expected += l.Size expected += l.Size
} }
var received atomic.Int64 var completed atomic.Int64
var g errgroup.Group var g errgroup.Group
g.SetLimit(r.maxStreams()) g.SetLimit(r.maxStreams())
for _, l := range layers { for _, l := range layers {
var received atomic.Int64
info, err := c.Get(l.Digest) info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size { if err == nil && info.Size == l.Size {
received.Add(l.Size) received.Add(l.Size)
completed.Add(l.Size)
t.update(l, l.Size, ErrCached) t.update(l, l.Size, ErrCached)
continue continue
} }
func() {
var wg sync.WaitGroup var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size) chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil { if err != nil {
t.update(l, 0, err) t.update(l, received.Load(), err)
continue return
} }
defer func() {
// Close the chunked writer when all chunks are
// downloaded.
//
// This is done as a background task in the
// group to allow the next layer to start while
// we wait for the final chunk in this layer to
// complete. It also ensures this is done
// before we exit Pull.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}()
for cs, err := range r.chunksums(ctx, name, l) { for cs, err := range r.chunksums(ctx, name, l) {
if err != nil { if err != nil {
...@@ -502,7 +526,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -502,7 +526,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// log and let in-flight downloads complete. // log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete // This will naturally trigger ErrIncomplete
// since received < expected bytes. // since received < expected bytes.
t.update(l, 0, err) t.update(l, received.Load(), err)
break break
} }
...@@ -516,8 +540,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -516,8 +540,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
cacheKeyDigest := blob.DigestFromBytes(cacheKey) cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest) _, err := c.Get(cacheKeyDigest)
if err == nil { if err == nil {
received.Add(cs.Chunk.Size()) recv := received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached) completed.Add(cs.Chunk.Size())
t.update(l, recv, ErrCached)
continue continue
} }
...@@ -536,10 +561,8 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -536,10 +561,8 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Not incorrect, just suboptimal - fix this in a // Not incorrect, just suboptimal - fix this in a
// future update. // future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey) _ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else { } else {
t.update(l, 0, err) t.update(l, received.Load(), err)
} }
wg.Done() wg.Done()
}() }()
...@@ -555,34 +578,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -555,34 +578,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
} }
defer res.Body.Close() defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update} tr := &trackingReader{
return chunked.Put(cs.Chunk, cs.Digest, body) l: l,
}) r: res.Body,
update: func(n int64) {
completed.Add(n)
recv := received.Add(n)
t.update(l, recv, nil)
},
} }
return chunked.Put(cs.Chunk, cs.Digest, tr)
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
}) })
} }
}()
}
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
if received.Load() != expected { if recv := completed.Load(); recv != expected {
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected) return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, recv, expected)
} }
md := blob.DigestFromBytes(m.Data) md := blob.DigestFromBytes(m.Data)
...@@ -973,7 +987,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) ...@@ -973,7 +987,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
return nil, ErrModelNotFound return nil, ErrModelNotFound
} }
re.Status = res.StatusCode re.status = res.StatusCode
return nil, &re return nil, &re
} }
return res, nil return res, nil
......
...@@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) { ...@@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
func checkErrCode(t *testing.T, err error, status int, code string) { func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper() t.Helper()
var e *Error var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code { if !errors.As(err, &e) || e.status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code) t.Errorf("err = %v; want %v %v", err, status, code)
} }
} }
...@@ -860,8 +860,8 @@ func TestPullChunksumStreaming(t *testing.T) { ...@@ -860,8 +860,8 @@ func TestPullChunksumStreaming(t *testing.T) {
// now send the second chunksum and ensure it kicks off work immediately // now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c")) fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 1 { if g := <-update; g != 3 {
t.Fatalf("got %d, want 1", g) t.Fatalf("got %d, want 3", g)
} }
csw.Close() csw.Close()
testutil.Check(t, <-errc) testutil.Check(t, <-errc)
...@@ -944,10 +944,10 @@ func TestPullChunksumsCached(t *testing.T) { ...@@ -944,10 +944,10 @@ func TestPullChunksumsCached(t *testing.T) {
_, err = c.Cache.Resolve("o.com/library/abc:latest") _, err = c.Cache.Resolve("o.com/library/abc:latest")
check(err) check(err)
if g := written.Load(); g != 3 { if g := written.Load(); g != 5 {
t.Fatalf("wrote %d bytes, want 3", g) t.Fatalf("wrote %d bytes, want 3", g)
} }
if g := cached.Load(); g != 2 { // "ab" should have been cached if g := cached.Load(); g != 2 { // "ab" should have been cached
t.Fatalf("cached %d bytes, want 3", g) t.Fatalf("cached %d bytes, want 5", g)
} }
} }
...@@ -34,10 +34,27 @@ func (t *Trace) update(l *Layer, n int64, err error) { ...@@ -34,10 +34,27 @@ func (t *Trace) update(l *Layer, n int64, err error) {
type traceKey struct{} type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace // WithTrace adds a trace to the context for transfer progress reporting.
// events.
func WithTrace(ctx context.Context, t *Trace) context.Context { func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t) old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
} }
var emptyTrace = &Trace{} var emptyTrace = &Trace{}
......
...@@ -9,13 +9,14 @@ import ( ...@@ -9,13 +9,14 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"maps"
"net/http" "net/http"
"slices"
"sync" "sync"
"time" "time"
"github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/internal/backoff"
) )
// Local implements an http.Handler for handling local Ollama API model // Local implements an http.Handler for handling local Ollama API model
...@@ -265,49 +266,53 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { ...@@ -265,49 +266,53 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
} }
return err return err
} }
return enc.Encode(progressUpdateJSON{Status: "success"}) enc.Encode(progressUpdateJSON{Status: "success"})
} return nil
maybeFlush := func() {
fl, _ := w.(http.Flusher)
if fl != nil {
fl.Flush()
}
} }
defer maybeFlush()
var mu sync.Mutex var mu sync.Mutex
progress := make(map[*ollama.Layer]int64) var progress []progressUpdateJSON
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() { flushProgress := func() {
defer maybeFlush()
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
mu.Lock() mu.Lock()
maps.Copy(progressCopy, progress) progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
mu.Unlock() mu.Unlock()
for l, n := range progressCopy { for _, p := range progress {
enc.Encode(progressUpdateJSON{ enc.Encode(p)
Digest: l.Digest, }
Total: l.Size, fl, _ := w.(http.Flusher)
Completed: n, if fl != nil {
}) fl.Flush()
} }
} }
defer flushProgress() defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer t := time.NewTicker(1<<63 - 1) // "unstarted" timer
start := sync.OnceFunc(func() { start := sync.OnceFunc(func() {
flushProgress() // flush initial state flushProgress() // flush initial state
t.Reset(100 * time.Millisecond) t.Reset(100 * time.Millisecond)
}) })
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{ ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) { Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 { if err != nil && !errors.Is(err, ollama.ErrCached) {
s.Logger.Error("pulling", "model", p.model(), "error", err)
return
}
func() {
mu.Lock()
defer mu.Unlock()
for i, p := range progress {
if p.Digest == l.Digest {
progress[i].Completed = n
return
}
}
progress = append(progress, progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
})
}()
// Block flushing progress updates until every // Block flushing progress updates until every
// layer is accounted for. Clients depend on a // layer is accounted for. Clients depend on a
// complete model size to calculate progress // complete model size to calculate progress
...@@ -315,18 +320,27 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { ...@@ -315,18 +320,27 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
// progress indicators would erratically jump // progress indicators would erratically jump
// as new layers are registered. // as new layers are registered.
start() start()
}
mu.Lock()
progress[l] += n
mu.Unlock()
}, },
}) })
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() (err error) {
done <- s.Client.Pull(ctx, p.model()) defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
var oe *ollama.Error
if errors.As(err, &oe) && oe.Temporary() {
continue // retry
}
return err
}
return nil
}() }()
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
for { for {
select { select {
case <-t.C: case <-t.C:
...@@ -341,7 +355,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { ...@@ -341,7 +355,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
status = fmt.Sprintf("error: %v", err) status = fmt.Sprintf("error: %v", err)
} }
enc.Encode(progressUpdateJSON{Status: status}) enc.Encode(progressUpdateJSON{Status: status})
return nil
} }
// Emulate old client pull progress (for now):
enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return nil return nil
} }
} }
......
...@@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local { ...@@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder { func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper() t.Helper()
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body)) ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
t.Logf("update: %s %d %v", l.Digest, n, err)
},
})
req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
return s.sendRequest(t, req) return s.sendRequest(t, req)
} }
...@@ -184,36 +189,34 @@ func TestServerPull(t *testing.T) { ...@@ -184,36 +189,34 @@ func TestServerPull(t *testing.T) {
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) { checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper() t.Helper()
if got.Code != 200 { if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code) t.Errorf("Code = %d; want 200", got.Code)
} }
gotlines := got.Body.String() gotlines := got.Body.String()
if strings.TrimSpace(gotlines) == "" {
gotlines = "<empty>"
}
t.Logf("got:\n%s", gotlines) t.Logf("got:\n%s", gotlines)
for want := range strings.Lines(wantlines) { for want := range strings.Lines(wantlines) {
want = strings.TrimSpace(want) want = strings.TrimSpace(want)
want, unwanted := strings.CutPrefix(want, "!") want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want) want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) { if !unwanted && !strings.Contains(gotlines, want) {
t.Errorf("! missing %q in body", want) t.Errorf("\t! missing %q in body", want)
} }
if unwanted && strings.Contains(gotlines, want) { if unwanted && strings.Contains(gotlines, want) {
t.Errorf("! unexpected %q in body", want) t.Errorf("\t! unexpected %q in body", want)
} }
} }
} }
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, ` checkResponse(got, `
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} {"status":"pulling manifest"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} {"status":"verifying sha256 digest"}
{"status":"writing manifest"}
{"status":"success"}
`) `)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment