download.go 11.4 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
	"errors"
	"fmt"
	"io"
9
	"log/slog"
Jeffrey Morgan's avatar
Jeffrey Morgan committed
10
	"math"
11
	"math/rand/v2"
12
	"net/http"
Michael Yang's avatar
Michael Yang committed
13
	"net/url"
14
	"os"
Michael Yang's avatar
Michael Yang committed
15
	"path/filepath"
16
	"strconv"
Michael Yang's avatar
Michael Yang committed
17
	"strings"
18
19
	"sync"
	"sync/atomic"
20
	"syscall"
21
	"time"
22

Michael Yang's avatar
Michael Yang committed
23
	"golang.org/x/sync/errgroup"
24

25
26
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
27
28
)

29
30
// maxRetries is the number of attempts made for a single part before the
// whole download fails with errMaxRetriesExceeded.
const maxRetries = 6

var (
	// errMaxRetriesExceeded wraps the last error of a part that failed
	// maxRetries times in a row.
	errMaxRetriesExceeded = errors.New("max retries exceeded")
	// errPartStalled is returned by the stall watchdog in downloadChunk when
	// a part has stopped receiving data.
	errPartStalled = errors.New("part stalled")
)

// blobDownloadManager maps a blob digest to its in-flight *blobDownload so
// that concurrent requests for the same digest share one download.
var blobDownloadManager sync.Map
37

38
39
40
type blobDownload struct {
	Name   string
	Digest string
41

42
43
	Total     int64
	Completed atomic.Int64
44

45
	Parts []*blobDownloadPart
46

47
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
48

49
	done       chan struct{}
Michael Yang's avatar
Michael Yang committed
50
	err        error
Michael Yang's avatar
Michael Yang committed
51
	references atomic.Int32
52
}
53

54
type blobDownloadPart struct {
55
56
57
58
59
60
61
	N         int
	Offset    int64
	Size      int64
	Completed atomic.Int64

	lastUpdatedMu sync.Mutex
	lastUpdated   time.Time
Michael Yang's avatar
Michael Yang committed
62
63
64
65

	*blobDownload `json:"-"`
}

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// jsonBlobDownloadPart is the plain-value JSON form of a blobDownloadPart:
// the same four persisted fields, with Completed as an ordinary int64 rather
// than an atomic.
type jsonBlobDownloadPart struct {
	N                       int
	Offset, Size, Completed int64
}

// MarshalJSON encodes the part as a jsonBlobDownloadPart, taking a snapshot
// of the atomic Completed counter at the time of the call.
func (p *blobDownloadPart) MarshalJSON() ([]byte, error) {
	snapshot := jsonBlobDownloadPart{N: p.N, Offset: p.Offset, Size: p.Size}
	snapshot.Completed = p.Completed.Load()
	return json.Marshal(snapshot)
}

// UnmarshalJSON decodes a jsonBlobDownloadPart from b and overwrites *p with
// it. Every field that is not part of the JSON form (including the embedded
// *blobDownload) is reset to its zero value; p is left untouched on error.
func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
	var decoded jsonBlobDownloadPart
	err := json.Unmarshal(b, &decoded)
	if err != nil {
		return err
	}

	*p = blobDownloadPart{N: decoded.N, Offset: decoded.Offset, Size: decoded.Size}
	p.Completed.Store(decoded.Completed)
	return nil
}

96
const (
97
	numDownloadParts          = 16
Michael Yang's avatar
Michael Yang committed
98
99
	minDownloadPartSize int64 = 100 * format.MegaByte
	maxDownloadPartSize int64 = 1000 * format.MegaByte
100
101
)

Michael Yang's avatar
Michael Yang committed
102
103
104
105
// Name returns the path of the file that stores this part's progress:
// "<blob path>-partial-<N>".
func (p *blobDownloadPart) Name() string {
	return strings.Join([]string{
		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
	}, "-")
}
107

108
// StartsAt returns the blob offset of the first byte of this part that has
// not been downloaded yet.
func (p *blobDownloadPart) StartsAt() int64 {
	return p.Offset + p.Completed.Load()
}

// StopsAt returns the blob offset one past the last byte of this part, i.e.
// the exclusive end of the range [Offset, Offset+Size).
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Offset + p.Size
	return end
}

116
117
118
// Write implements io.Writer for progress accounting only: it discards b,
// adds len(b) to the owning download's Completed counter and records the
// current time as the part's last activity. It never fails.
//
// It is used as the side branch of an io.TeeReader in downloadChunk, so it
// observes every chunk read from the response body.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	n = len(b)
	p.blobDownload.Completed.Add(int64(n))

	p.lastUpdatedMu.Lock()
	p.lastUpdated = time.Now()
	p.lastUpdatedMu.Unlock()

	return n, nil
}

Michael Yang's avatar
Michael Yang committed
125
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
126
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
127
	if err != nil {
128
		return err
129
130
	}

131
132
	b.done = make(chan struct{})

Michael Yang's avatar
Michael Yang committed
133
	for _, partFilePath := range partFilePaths {
134
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
135
136
		if err != nil {
			return err
137
138
		}

139
		b.Total += part.Size
140
		b.Completed.Add(part.Completed.Load())
141
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
142
	}
143

144
	if len(b.Parts) == 0 {
Michael Yang's avatar
Michael Yang committed
145
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
146
		if err != nil {
Michael Yang's avatar
Michael Yang committed
147
148
149
150
			return err
		}
		defer resp.Body.Close()

151
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
152

Michael Yang's avatar
Michael Yang committed
153
		size := b.Total / numDownloadParts
154
155
156
157
158
159
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
160

161
		var offset int64
162
163
164
165
166
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
167
			if err := b.newPart(offset, size); err != nil {
168
				return err
Michael Yang's avatar
Michael Yang committed
169
170
171
			}

			offset += size
172
		}
173
174
175
	}

	if len(b.Parts) > 0 {
176
		slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
177
	}
178

179
180
181
	return nil
}

Michael Yang's avatar
Michael Yang committed
182
// Run performs the download and publishes its outcome: the result of run is
// stored in b.err and b.done is closed afterwards, which releases every
// caller blocked in Wait. Prepare must have been called first, since it
// creates b.done.
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
	defer close(b.done)
	b.err = b.run(ctx, requestURL, opts)
}

187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
	var n int
	return func(ctx context.Context) error {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		n++

		// n^2 backoff timer is a little smoother than the
		// common choice of 2^n.
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		// Randomize the delay between 0.5-1.5 x msec, in order
		// to prevent accidental "thundering herd" problems.
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
			return nil
		}
	}
}

213
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
214
215
216
	defer blobDownloadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

Michael Yang's avatar
Michael Yang committed
217
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
218
	if err != nil {
219
		return err
Michael Yang's avatar
Michael Yang committed
220
	}
221
	defer file.Close()
222
	setSparse(file)
223

Michael Yang's avatar
Michael Yang committed
224
	_ = file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
225

226
227
228
229
230
231
232
233
234
235
236
237
238
	directURL, err := func() (*url.URL, error) {
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		backoff := newBackoff(10 * time.Second)
		for {
			// shallow clone opts to be used in the closure
			// without affecting the outer opts.
			newOpts := new(registryOptions)
			*newOpts = *opts

			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
				if len(via) > 10 {
239
					return errors.New("maximum redirects exceeded (10) for directURL")
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
				}

				// if the hostname is the same, allow the redirect
				if req.URL.Hostname() == requestURL.Hostname() {
					return nil
				}

				// stop at the first redirect that is not
				// the same hostname as the original
				// request.
				return http.ErrUseLastResponse
			}

			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
			if err != nil {
				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
				if err := backoff(ctx); err != nil {
					return nil, err
				}
				continue
			}
			defer resp.Body.Close()
262
			if resp.StatusCode != http.StatusTemporaryRedirect && resp.StatusCode != http.StatusOK {
263
264
265
266
267
268
269
270
271
				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
			}
			return resp.Location()
		}
	}()
	if err != nil {
		return err
	}

272
273
	g, inner := errgroup.WithContext(ctx)
	g.SetLimit(numDownloadParts)
274
275
	for i := range b.Parts {
		part := b.Parts[i]
276
		if part.Completed.Load() == part.Size {
Michael Yang's avatar
Michael Yang committed
277
278
			continue
		}
279

280
		g.Go(func() error {
Michael Yang's avatar
Michael Yang committed
281
			var err error
Michael Yang's avatar
Michael Yang committed
282
			for try := 0; try < maxRetries; try++ {
283
				w := io.NewOffsetWriter(file, part.StartsAt())
284
				err = b.downloadChunk(inner, directURL, w, part)
285
				switch {
286
287
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
288
					return err
289
290
291
				case errors.Is(err, errPartStalled):
					try--
					continue
292
				case err != nil:
Jeffrey Morgan's avatar
Jeffrey Morgan committed
293
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
294
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
Michael Yang's avatar
Michael Yang committed
295
					time.Sleep(sleep)
Michael Yang's avatar
Michael Yang committed
296
					continue
297
298
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
299
300
301
				}
			}

Michael Yang's avatar
Michael Yang committed
302
			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
Michael Yang's avatar
Michael Yang committed
303
		})
304
305
	}

Michael Yang's avatar
Michael Yang committed
306
	if err := g.Wait(); err != nil {
307
		return err
308
309
	}

310
311
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
312
		return err
Michael Yang's avatar
Michael Yang committed
313
314
	}

315
	for i := range b.Parts {
316
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
317
			return err
Michael Yang's avatar
Michael Yang committed
318
		}
319
320
	}

Michael Yang's avatar
Michael Yang committed
321
	if err := os.Rename(file.Name(), b.Name); err != nil {
322
		return err
Michael Yang's avatar
Michael Yang committed
323
324
	}

325
	return nil
326
327
}

328
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
329
330
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
331
332
333
334
335
336
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
		if err != nil {
			return err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := http.DefaultClient.Do(req)
337
338
339
340
		if err != nil {
			return err
		}
		defer resp.Body.Close()
341

342
		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
343
344
345
346
347
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}
348

349
		part.Completed.Add(n)
350
351
352
353
354
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
Michael Yang's avatar
Michael Yang committed
355
		return err
356
357
358
359
360
361
362
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		for {
			select {
			case <-ticker.C:
363
				if part.Completed.Load() >= part.Size {
364
365
366
					return nil
				}

367
368
369
370
				part.lastUpdatedMu.Lock()
				lastUpdated := part.lastUpdated
				part.lastUpdatedMu.Unlock()

371
				if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
372
373
					const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
					slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
374
					// reset last updated
375
					part.lastUpdatedMu.Lock()
376
					part.lastUpdated = time.Time{}
377
					part.lastUpdatedMu.Unlock()
378
379
380
381
382
383
384
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
Michael Yang's avatar
Michael Yang committed
385

386
	return g.Wait()
387
388
}

Michael Yang's avatar
Michael Yang committed
389
390
391
392
393
394
395
396
397
398
// newPart creates the next part of the download, covering size bytes starting
// at offset, and numbers it with the current length of b.Parts. The part's
// progress file is written first; the part is appended to b.Parts only if
// that succeeds, otherwise the write error is returned.
func (b *blobDownload) newPart(offset, size int64) error {
	part := &blobDownloadPart{
		blobDownload: b,
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
	}

	err := b.writePart(part.Name(), part)
	if err == nil {
		b.Parts = append(b.Parts, part)
	}
	return err
}

399
400
401
402
403
404
405
406
407
408
409
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
410

Michael Yang's avatar
Michael Yang committed
411
	part.blobDownload = b
412
	return &part, nil
Michael Yang's avatar
Michael Yang committed
413
414
}

415
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
Michael Yang's avatar
Michael Yang committed
416
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
Michael Yang's avatar
Michael Yang committed
417
418
	if err != nil {
		return err
419
	}
Michael Yang's avatar
Michael Yang committed
420
	defer partFile.Close()
421

Michael Yang's avatar
Michael Yang committed
422
	return json.NewEncoder(partFile).Encode(part)
423
}
424

Michael Yang's avatar
Michael Yang committed
425
426
427
428
429
430
431
432
433
434
// acquire registers one more user of the download by incrementing the
// reference count. Each call is paired with a release (see Wait).
func (b *blobDownload) acquire() {
	b.references.Add(1)
}

// release drops one reference to the download and cancels it when the last
// reference is gone.
//
// b.CancelFunc is assigned inside run, which executes on its own goroutine,
// so a waiter that leaves very early can get here before it has been set.
// Calling a nil func would panic, so that case is skipped; the download then
// keeps running until it finishes on its own.
func (b *blobDownload) release() {
	if b.references.Add(-1) == 0 && b.CancelFunc != nil {
		b.CancelFunc()
	}
}

435
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
436
437
	b.acquire()
	defer b.release()
438
439
440
441

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
442
443
		case <-b.done:
			return b.err
444
		case <-ticker.C:
445
446
447
448
449
450
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})
451
452
453
454
455
456
457
458
459
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

type downloadOpts struct {
	mp      ModelPath
	digest  string
Michael Yang's avatar
Michael Yang committed
460
	regOpts *registryOptions
461
462
463
464
	fn      func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
465
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
466
467
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
468
		return false, err
469
470
471
472
473
474
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
475
		return false, err
476
477
	default:
		opts.fn(api.ProgressResponse{
Jeffrey Morgan's avatar
Jeffrey Morgan committed
478
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
479
480
481
482
483
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

484
		return true, nil
485
486
	}

Michael Yang's avatar
names  
Michael Yang committed
487
488
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
489
490
491
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
492
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
493
			blobDownloadManager.Delete(opts.digest)
494
			return false, err
495
496
		}

Michael Yang's avatar
Michael Yang committed
497
		//nolint:contextcheck
Michael Yang's avatar
names  
Michael Yang committed
498
		go download.Run(context.Background(), requestURL, opts.regOpts)
499
500
	}

501
	return false, download.Wait(ctx, opts.fn)
502
}