Unverified Commit ef27d52e authored by Blake Mizerany's avatar Blake Mizerany Committed by GitHub
Browse files

server/internal/client/ollama: cache completed chunks (#9933)

This change adds tracking of download chunks during the pull process so
that subsequent pulls can skip downloading already completed chunks.
This works across restarts of ollama.

Currently, download state will be lost if a prune is triggered during a
pull (e.g. restart or remove). This issue should be addressed in a
follow-up PR.
parent b2a46529
...@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { ...@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
return err return err
} }
// canRetry reports whether err is worth another attempt. It answers true
// only when err is, or wraps, an *Error whose Status is 500 or greater;
// every other error (including nil) is treated as non-retryable.
func canRetry(err error) bool {
	var regErr *Error
	return errors.As(err, &regErr) && regErr.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and // trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read. // calls the update function with the layer, the number of bytes read.
// //
...@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
break break
} }
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
wg.Add(1) wg.Add(1)
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { defer func() {
if err == nil { if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size()) received.Add(cs.Chunk.Size())
} else { } else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err) t.update(l, 0, err)
} }
wg.Done() wg.Done()
}() }()
...@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { ...@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err return err
} }
if received.Load() != expected { if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected) return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
} }
md := blob.DigestFromBytes(m.Data) md := blob.DigestFromBytes(m.Data)
...@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer { ...@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil return nil
} }
// All returns an iterator over the layers of the manifest: the config
// layer first, then each entry of m.Layers in slice order. Iteration
// stops as soon as the consumer's yield function returns false.
//
// A nil Config is skipped instead of being yielded. Size already treats
// Config as optional, and yielding a nil *Layer would force every
// consumer to nil-check before touching the layer.
func (m *Manifest) All() iter.Seq[*Layer] {
	return func(yield func(*Layer) bool) {
		// Only offer the config layer when the manifest actually has one.
		if m.Config != nil && !yield(m.Config) {
			return
		}
		for _, l := range m.Layers {
			if !yield(l) {
				return
			}
		}
	}
}
// Size returns the sum of the Size fields of every layer in m.Layers
// plus, when the manifest has one, the Size of the config layer. A
// manifest with no config and no layers has size zero.
func (m *Manifest) Size() int64 {
	total := int64(0)
	for _, layer := range m.Layers {
		total += layer.Size
	}
	if cfg := m.Config; cfg != nil {
		total += cfg.Size
	}
	return total
}
// MarshalJSON implements json.Marshaler. // MarshalJSON implements json.Marshaler.
// //
// NOTE: It adds an empty config object to the manifest, which is required by // NOTE: It adds an empty config object to the manifest, which is required by
...@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se ...@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return return
} }
// A chunksums response is a sequence of chunksums in a // The response is a sequence of chunksums.
// simple, easy to parse line-oriented format. //
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
// //
// Example: // < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
// //
// >> GET /v2/<namespace>/<model>/chunksums/<digest> // The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range of the chunk in the blob.
// //
// << HTTP/1.1 200 OK // Ranges may be used directly in Range headers like
// << Content-Location: <blobURL> // "bytes=<start>-<end>".
// <<
// << <digest> <start>-<end>
// << ...
// //
// The blobURL is the URL to download the chunks from. // The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s", chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme, scheme,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment