Unverified Commit c8af3c2d authored by Blake Mizerany's avatar Blake Mizerany Committed by GitHub
Browse files

server: reuse original download URL for images (#5962)

This changes the registry client to reuse the original download URL
it gets on the first redirect response for all subsequent requests,
preventing thundering herd issues when hot new LLMs are released.
parent 455e6117
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"math" "math"
"math/rand/v2"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
...@@ -141,6 +142,32 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis ...@@ -141,6 +142,32 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
b.err = b.run(ctx, requestURL, opts) b.err = b.run(ctx, requestURL, opts)
} }
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
var n int
return func(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
defer blobDownloadManager.Delete(b.Digest) defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
...@@ -153,6 +180,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis ...@@ -153,6 +180,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
backoff := newBackoff(10 * time.Second)
for {
// shallow clone opts to be used in the closure
// without affecting the outer opts.
newOpts := new(registryOptions)
*newOpts = *opts
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errors.New("maxium redirects exceeded (10) for directURL")
}
// if the hostname is the same, allow the redirect
if req.URL.Hostname() == requestURL.Hostname() {
return nil
}
// stop at the first redirect that is not
// the same hostname as the original
// request.
return http.ErrUseLastResponse
}
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
if err != nil {
slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
if err := backoff(ctx); err != nil {
return nil, err
}
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusTemporaryRedirect {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
return resp.Location()
}
}()
if err != nil {
return err
}
g, inner := errgroup.WithContext(ctx) g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts) g.SetLimit(numDownloadParts)
for i := range b.Parts { for i := range b.Parts {
...@@ -165,7 +238,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis ...@@ -165,7 +238,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt()) w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, requestURL, w, part, opts) err = b.downloadChunk(inner, directURL, w, part, opts)
switch { switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
......
...@@ -54,6 +54,8 @@ type registryOptions struct { ...@@ -54,6 +54,8 @@ type registryOptions struct {
Username string Username string
Password string Password string
Token string Token string
CheckRedirect func(req *http.Request, via []*http.Request) error
} }
type Model struct { type Model struct {
...@@ -1131,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header ...@@ -1131,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength req.ContentLength = contentLength
} }
resp, err := http.DefaultClient.Do(req) resp, err := (&http.Client{
CheckRedirect: regOpts.CheckRedirect,
}).Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment