download.go 11.3 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
	"errors"
	"fmt"
	"io"
9
	"log/slog"
Jeffrey Morgan's avatar
Jeffrey Morgan committed
10
	"math"
11
	"math/rand/v2"
12
	"net/http"
Michael Yang's avatar
Michael Yang committed
13
	"net/url"
14
	"os"
Michael Yang's avatar
Michael Yang committed
15
	"path/filepath"
16
	"strconv"
Michael Yang's avatar
Michael Yang committed
17
	"strings"
18
19
	"sync"
	"sync/atomic"
20
	"syscall"
21
	"time"
22

Michael Yang's avatar
Michael Yang committed
23
	"golang.org/x/sync/errgroup"
24

25
26
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
27
28
)

29
30
// maxRetries is the number of attempts made to download a single part
// before the download fails with errMaxRetriesExceeded. Attempts that end
// in errPartStalled are not counted against this limit.
const maxRetries = 6

Michael Yang's avatar
lint  
Michael Yang committed
31
32
33
34
var (
	// errMaxRetriesExceeded wraps the last error of a part whose download
	// failed maxRetries times.
	errMaxRetriesExceeded = errors.New("max retries exceeded")
	// errPartStalled is returned by downloadChunk when a part has received
	// no data for more than five seconds.
	errPartStalled        = errors.New("part stalled")
)
35

36
// blobDownloadManager tracks in-flight downloads keyed by blob digest, so
// that concurrent requests for the same digest share one *blobDownload.
var blobDownloadManager sync.Map
37

38
39
40
type blobDownload struct {
	Name   string
	Digest string
41

42
43
	Total     int64
	Completed atomic.Int64
44

45
	Parts []*blobDownloadPart
46

47
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
48

49
	done       chan struct{}
Michael Yang's avatar
Michael Yang committed
50
	err        error
Michael Yang's avatar
Michael Yang committed
51
	references atomic.Int32
52
}
53

54
type blobDownloadPart struct {
55
56
57
58
59
60
61
	N         int
	Offset    int64
	Size      int64
	Completed atomic.Int64

	lastUpdatedMu sync.Mutex
	lastUpdated   time.Time
Michael Yang's avatar
Michael Yang committed
62
63
64
65

	*blobDownload `json:"-"`
}

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// jsonBlobDownloadPart is the JSON form of a blobDownloadPart, with
// Completed stored as a plain integer instead of an atomic counter.
type jsonBlobDownloadPart struct {
	N         int
	Offset    int64
	Size      int64
	Completed int64
}

// MarshalJSON encodes the part's index, offset, size and a snapshot of its
// completed byte count.
func (p *blobDownloadPart) MarshalJSON() ([]byte, error) {
	snapshot := jsonBlobDownloadPart{
		N:         p.N,
		Offset:    p.Offset,
		Size:      p.Size,
		Completed: p.Completed.Load(),
	}
	return json.Marshal(snapshot)
}

// UnmarshalJSON replaces the receiver with the part described by b. Fields
// that are not part of the JSON form are reset to their zero values.
func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
	var decoded jsonBlobDownloadPart
	err := json.Unmarshal(b, &decoded)
	if err != nil {
		return err
	}

	// Overwrite the whole struct so no state from a previous value survives.
	*p = blobDownloadPart{N: decoded.N, Offset: decoded.Offset, Size: decoded.Size}
	p.Completed.Store(decoded.Completed)
	return nil
}

96
97
const (
	numDownloadParts          = 64
Michael Yang's avatar
Michael Yang committed
98
99
	minDownloadPartSize int64 = 100 * format.MegaByte
	maxDownloadPartSize int64 = 1000 * format.MegaByte
100
101
)

Michael Yang's avatar
Michael Yang committed
102
103
104
105
// Name returns the path of the part's metadata file:
// "<download name>-partial-<N>".
func (p *blobDownloadPart) Name() string {
	segments := []string{p.blobDownload.Name, "partial", strconv.Itoa(p.N)}
	return strings.Join(segments, "-")
}
107

108
// StartsAt returns the blob offset of the first byte of this part that has
// not been downloaded yet.
func (p *blobDownloadPart) StartsAt() int64 {
	done := p.Completed.Load()
	return p.Offset + done
}

// StopsAt returns the blob offset one past the last byte of this part.
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Offset + p.Size
	return end
}

116
117
118
// Write records progress for the bytes in b: it adds len(b) to the owning
// download's completed counter and stamps the part as just updated. It
// stores nothing itself and always reports len(b) bytes written.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	received := len(b)
	p.blobDownload.Completed.Add(int64(received))

	p.lastUpdatedMu.Lock()
	p.lastUpdated = time.Now()
	p.lastUpdatedMu.Unlock()

	return received, nil
}

Michael Yang's avatar
Michael Yang committed
125
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
126
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
127
	if err != nil {
128
		return err
129
130
	}

131
132
	b.done = make(chan struct{})

Michael Yang's avatar
Michael Yang committed
133
	for _, partFilePath := range partFilePaths {
134
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
135
136
		if err != nil {
			return err
137
138
		}

139
		b.Total += part.Size
140
		b.Completed.Add(part.Completed.Load())
141
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
142
	}
143

144
	if len(b.Parts) == 0 {
Michael Yang's avatar
Michael Yang committed
145
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
146
		if err != nil {
Michael Yang's avatar
Michael Yang committed
147
148
149
150
			return err
		}
		defer resp.Body.Close()

151
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
152

Michael Yang's avatar
Michael Yang committed
153
		size := b.Total / numDownloadParts
154
155
156
157
158
159
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
160

161
		var offset int64
162
163
164
165
166
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
167
			if err := b.newPart(offset, size); err != nil {
168
				return err
Michael Yang's avatar
Michael Yang committed
169
170
171
			}

			offset += size
172
173
174
		}
	}

175
	slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
176
177
178
	return nil
}

Michael Yang's avatar
Michael Yang committed
179
// Run performs the download and records its outcome in b.err. The done
// channel is closed when Run returns, after b.err has been set.
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
	defer func() {
		close(b.done)
	}()

	result := b.run(ctx, requestURL, opts)
	b.err = result
}

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
	var n int
	return func(ctx context.Context) error {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		n++

		// n^2 backoff timer is a little smoother than the
		// common choice of 2^n.
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		// Randomize the delay between 0.5-1.5 x msec, in order
		// to prevent accidental "thundering herd" problems.
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
			return nil
		}
	}
}

210
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
211
212
213
	defer blobDownloadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

Michael Yang's avatar
Michael Yang committed
214
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
215
	if err != nil {
216
		return err
Michael Yang's avatar
Michael Yang committed
217
	}
218
	defer file.Close()
219
	setSparse(file)
220

Michael Yang's avatar
Michael Yang committed
221
	_ = file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
222

223
224
225
226
227
228
229
230
231
232
233
234
235
	directURL, err := func() (*url.URL, error) {
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		backoff := newBackoff(10 * time.Second)
		for {
			// shallow clone opts to be used in the closure
			// without affecting the outer opts.
			newOpts := new(registryOptions)
			*newOpts = *opts

			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
				if len(via) > 10 {
236
					return errors.New("maximum redirects exceeded (10) for directURL")
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
				}

				// if the hostname is the same, allow the redirect
				if req.URL.Hostname() == requestURL.Hostname() {
					return nil
				}

				// stop at the first redirect that is not
				// the same hostname as the original
				// request.
				return http.ErrUseLastResponse
			}

			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
			if err != nil {
				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
				if err := backoff(ctx); err != nil {
					return nil, err
				}
				continue
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusTemporaryRedirect {
				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
			}
			return resp.Location()
		}
	}()
	if err != nil {
		return err
	}

269
270
	g, inner := errgroup.WithContext(ctx)
	g.SetLimit(numDownloadParts)
271
272
	for i := range b.Parts {
		part := b.Parts[i]
273
		if part.Completed.Load() == part.Size {
Michael Yang's avatar
Michael Yang committed
274
275
			continue
		}
276

277
		g.Go(func() error {
Michael Yang's avatar
Michael Yang committed
278
			var err error
Michael Yang's avatar
Michael Yang committed
279
			for try := 0; try < maxRetries; try++ {
280
				w := io.NewOffsetWriter(file, part.StartsAt())
281
				err = b.downloadChunk(inner, directURL, w, part)
282
				switch {
283
284
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
285
					return err
286
287
288
				case errors.Is(err, errPartStalled):
					try--
					continue
289
				case err != nil:
Jeffrey Morgan's avatar
Jeffrey Morgan committed
290
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
291
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
Michael Yang's avatar
Michael Yang committed
292
					time.Sleep(sleep)
Michael Yang's avatar
Michael Yang committed
293
					continue
294
295
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
296
297
298
				}
			}

Michael Yang's avatar
Michael Yang committed
299
			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
Michael Yang's avatar
Michael Yang committed
300
		})
301
302
	}

Michael Yang's avatar
Michael Yang committed
303
	if err := g.Wait(); err != nil {
304
		return err
305
306
	}

307
308
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
309
		return err
Michael Yang's avatar
Michael Yang committed
310
311
	}

312
	for i := range b.Parts {
313
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
314
			return err
Michael Yang's avatar
Michael Yang committed
315
		}
316
317
	}

Michael Yang's avatar
Michael Yang committed
318
	if err := os.Rename(file.Name(), b.Name); err != nil {
319
		return err
Michael Yang's avatar
Michael Yang committed
320
321
	}

322
	return nil
323
324
}

325
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
326
327
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
328
329
330
331
332
333
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
		if err != nil {
			return err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := http.DefaultClient.Do(req)
334
335
336
337
		if err != nil {
			return err
		}
		defer resp.Body.Close()
338

339
		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
340
341
342
343
344
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}
345

346
		part.Completed.Add(n)
347
348
349
350
351
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
Michael Yang's avatar
Michael Yang committed
352
		return err
353
354
355
356
357
358
359
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		for {
			select {
			case <-ticker.C:
360
				if part.Completed.Load() >= part.Size {
361
362
363
					return nil
				}

364
365
366
367
368
				part.lastUpdatedMu.Lock()
				lastUpdated := part.lastUpdated
				part.lastUpdatedMu.Unlock()

				if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
369
370
					const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
					slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
371
					// reset last updated
372
					part.lastUpdatedMu.Lock()
373
					part.lastUpdated = time.Time{}
374
					part.lastUpdatedMu.Unlock()
375
376
377
378
379
380
381
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
Michael Yang's avatar
Michael Yang committed
382

383
	return g.Wait()
384
385
}

Michael Yang's avatar
Michael Yang committed
386
387
388
389
390
391
392
393
394
395
// newPart creates the next part of the download covering size bytes from
// offset, persists its metadata, and appends it to b.Parts. Nothing is
// appended if the metadata cannot be written.
func (b *blobDownload) newPart(offset, size int64) error {
	part := &blobDownloadPart{
		blobDownload: b,
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
	}

	err := b.writePart(part.Name(), part)
	if err != nil {
		return err
	}

	b.Parts = append(b.Parts, part)
	return nil
}

396
397
398
399
400
401
402
403
404
405
406
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
407

Michael Yang's avatar
Michael Yang committed
408
	part.blobDownload = b
409
	return &part, nil
Michael Yang's avatar
Michael Yang committed
410
411
}

412
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
Michael Yang's avatar
Michael Yang committed
413
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
Michael Yang's avatar
Michael Yang committed
414
415
	if err != nil {
		return err
416
	}
Michael Yang's avatar
Michael Yang committed
417
	defer partFile.Close()
418

Michael Yang's avatar
Michael Yang committed
419
	return json.NewEncoder(partFile).Encode(part)
420
}
421

Michael Yang's avatar
Michael Yang committed
422
423
424
425
426
427
428
429
430
431
// acquire registers one more reference to the download.
func (b *blobDownload) acquire() {
	b.references.Add(1)
}

// release drops one reference to the download and cancels it when the last
// reference is gone.
func (b *blobDownload) release() {
	if b.references.Add(-1) != 0 {
		return
	}

	// CancelFunc is only installed once run has started. A waiter that
	// leaves before then would otherwise call a nil func and panic.
	// NOTE(review): this read is not synchronized with the assignment in
	// run; the guard prevents the panic but not the underlying race.
	if b.CancelFunc != nil {
		b.CancelFunc()
	}
}

432
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
433
434
	b.acquire()
	defer b.release()
435
436
437
438

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
439
440
		case <-b.done:
			return b.err
441
		case <-ticker.C:
442
443
444
445
446
447
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})
448
449
450
451
452
453
454
455
456
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

type downloadOpts struct {
	mp      ModelPath
	digest  string
Michael Yang's avatar
Michael Yang committed
457
	regOpts *registryOptions
458
459
460
461
	fn      func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
462
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
463
464
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
465
		return false, err
466
467
468
469
470
471
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
472
		return false, err
473
474
	default:
		opts.fn(api.ProgressResponse{
Jeffrey Morgan's avatar
Jeffrey Morgan committed
475
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
476
477
478
479
480
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

481
		return true, nil
482
483
	}

Michael Yang's avatar
names  
Michael Yang committed
484
485
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
486
487
488
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
489
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
490
			blobDownloadManager.Delete(opts.digest)
491
			return false, err
492
493
		}

Michael Yang's avatar
Michael Yang committed
494
		//nolint:contextcheck
Michael Yang's avatar
names  
Michael Yang committed
495
		go download.Run(context.Background(), requestURL, opts.regOpts)
496
497
	}

498
	return false, download.Wait(ctx, opts.fn)
499
}