download.go 11.4 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
	"errors"
	"fmt"
	"io"
9
	"log/slog"
Jeffrey Morgan's avatar
Jeffrey Morgan committed
10
	"math"
11
	"math/rand/v2"
12
	"net/http"
Michael Yang's avatar
Michael Yang committed
13
	"net/url"
14
	"os"
Michael Yang's avatar
Michael Yang committed
15
	"path/filepath"
16
	"strconv"
Michael Yang's avatar
Michael Yang committed
17
	"strings"
18
19
	"sync"
	"sync/atomic"
20
	"syscall"
21
	"time"
22

Michael Yang's avatar
Michael Yang committed
23
	"golang.org/x/sync/errgroup"
24

25
26
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
27
28
)

29
30
// maxRetries is how many failed attempts run allows a single part before it
// gives up on that part; stalls (errPartStalled) do not count toward it.
const maxRetries = 6

var (
	// errMaxRetriesExceeded is wrapped together with the last error of a
	// part that used up all of its attempts.
	errMaxRetriesExceeded = errors.New("max retries exceeded")

	// errPartStalled is returned by downloadChunk's watcher when a part
	// stops receiving data; run retries it without consuming an attempt.
	errPartStalled = errors.New("part stalled")

	// blobDownloadManager maps a blob digest to its in-flight *blobDownload
	// so concurrent pulls of the same blob share a single transfer.
	blobDownloadManager sync.Map
)
37

38
39
40
type blobDownload struct {
	Name   string
	Digest string
41

42
43
	Total     int64
	Completed atomic.Int64
44

45
	Parts []*blobDownloadPart
46

47
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
48

49
	done       chan struct{}
Michael Yang's avatar
Michael Yang committed
50
	err        error
Michael Yang's avatar
Michael Yang committed
51
	references atomic.Int32
52
}
53

54
type blobDownloadPart struct {
55
56
57
58
59
60
61
	N         int
	Offset    int64
	Size      int64
	Completed atomic.Int64

	lastUpdatedMu sync.Mutex
	lastUpdated   time.Time
Michael Yang's avatar
Michael Yang committed
62
63
64
65

	*blobDownload `json:"-"`
}

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// jsonBlobDownloadPart is the on-disk JSON shape of a blobDownloadPart: the
// same persisted fields, with Completed as a plain int64 instead of an
// atomic counter.
type jsonBlobDownloadPart struct {
	N int

	Offset int64
	Size   int64

	Completed int64
}

// MarshalJSON encodes the persisted fields of p through jsonBlobDownloadPart,
// taking a single snapshot of the atomic Completed counter so it is written
// as a plain number.
func (p *blobDownloadPart) MarshalJSON() ([]byte, error) {
	snapshot := jsonBlobDownloadPart{
		N:      p.N,
		Offset: p.Offset,
		Size:   p.Size,
	}
	snapshot.Completed = p.Completed.Load()
	return json.Marshal(snapshot)
}

// UnmarshalJSON decodes the format produced by MarshalJSON. On success p is
// rebuilt from scratch: the persisted fields are restored, while lastUpdated
// and the owning *blobDownload are reset to their zero values (readPart
// re-attaches the owner). On a decode error p is left untouched.
func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
	var decoded jsonBlobDownloadPart
	err := json.Unmarshal(b, &decoded)
	if err != nil {
		return err
	}

	*p = blobDownloadPart{N: decoded.N, Offset: decoded.Offset, Size: decoded.Size}
	p.Completed.Store(decoded.Completed)
	return nil
}

96
97
const (
	numDownloadParts          = 64
Michael Yang's avatar
Michael Yang committed
98
99
	minDownloadPartSize int64 = 100 * format.MegaByte
	maxDownloadPartSize int64 = 1000 * format.MegaByte
100
101
)

Michael Yang's avatar
Michael Yang committed
102
103
104
105
// Name returns the path of p's JSON metadata file, "<blob Name>-partial-<N>".
// The field is reached explicitly through p.blobDownload because the method
// itself shadows the embedded Name field.
func (p *blobDownloadPart) Name() string {
	segments := []string{p.blobDownload.Name, "partial", strconv.Itoa(p.N)}
	return strings.Join(segments, "-")
}
107

108
// StartsAt returns the absolute blob offset of the first byte of p that has
// not been downloaded yet.
func (p *blobDownloadPart) StartsAt() int64 {
	done := p.Completed.Load()
	return done + p.Offset
}

// StopsAt returns the absolute blob offset just past p's last byte, i.e. the
// exclusive end of its range.
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Offset + p.Size
	return end
}

116
117
118
// Write lets p act as a progress sink (downloadChunk tees the response body
// into it). It discards the bytes, adds len(b) to the owning download's
// Completed counter, and stamps lastUpdated for the stall watcher. p's own
// Completed is not touched here. Write never fails.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	size := len(b)
	p.blobDownload.Completed.Add(int64(size))

	p.lastUpdatedMu.Lock()
	defer p.lastUpdatedMu.Unlock()
	p.lastUpdated = time.Now()

	return size, nil
}

Michael Yang's avatar
Michael Yang committed
125
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
126
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
127
	if err != nil {
128
		return err
129
130
	}

131
132
	b.done = make(chan struct{})

Michael Yang's avatar
Michael Yang committed
133
	for _, partFilePath := range partFilePaths {
134
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
135
136
		if err != nil {
			return err
137
138
		}

139
		b.Total += part.Size
140
		b.Completed.Add(part.Completed.Load())
141
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
142
	}
143

144
	if len(b.Parts) == 0 {
Michael Yang's avatar
Michael Yang committed
145
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
146
		if err != nil {
Michael Yang's avatar
Michael Yang committed
147
148
149
150
			return err
		}
		defer resp.Body.Close()

151
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
152

Michael Yang's avatar
Michael Yang committed
153
		size := b.Total / numDownloadParts
154
155
156
157
158
159
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
160

161
		var offset int64
162
163
164
165
166
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
167
			if err := b.newPart(offset, size); err != nil {
168
				return err
Michael Yang's avatar
Michael Yang committed
169
170
171
			}

			offset += size
172
173
174
		}
	}

175
	slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
176
177
178
	return nil
}

Michael Yang's avatar
Michael Yang committed
179
// Run executes the download laid out by Prepare and records the outcome in
// b.err. It then closes b.done so every Wait caller observes completion.
// downloadBlob starts it on its own goroutine with context.Background().
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
	defer close(b.done)

	err := b.run(ctx, requestURL, opts)
	b.err = err
}

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
	var n int
	return func(ctx context.Context) error {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		n++

		// n^2 backoff timer is a little smoother than the
		// common choice of 2^n.
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		// Randomize the delay between 0.5-1.5 x msec, in order
		// to prevent accidental "thundering herd" problems.
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
			return nil
		}
	}
}

210
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
211
212
213
	defer blobDownloadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

Michael Yang's avatar
Michael Yang committed
214
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
215
	if err != nil {
216
		return err
Michael Yang's avatar
Michael Yang committed
217
	}
218
	defer file.Close()
219
220
221
	if err := setSparse(file); err != nil {
		return err
	}
222

Michael Yang's avatar
Michael Yang committed
223
	_ = file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
224

225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
	directURL, err := func() (*url.URL, error) {
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		backoff := newBackoff(10 * time.Second)
		for {
			// shallow clone opts to be used in the closure
			// without affecting the outer opts.
			newOpts := new(registryOptions)
			*newOpts = *opts

			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
				if len(via) > 10 {
					return errors.New("maxium redirects exceeded (10) for directURL")
				}

				// if the hostname is the same, allow the redirect
				if req.URL.Hostname() == requestURL.Hostname() {
					return nil
				}

				// stop at the first redirect that is not
				// the same hostname as the original
				// request.
				return http.ErrUseLastResponse
			}

			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
			if err != nil {
				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
				if err := backoff(ctx); err != nil {
					return nil, err
				}
				continue
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusTemporaryRedirect {
				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
			}
			return resp.Location()
		}
	}()
	if err != nil {
		return err
	}

271
272
	g, inner := errgroup.WithContext(ctx)
	g.SetLimit(numDownloadParts)
273
274
	for i := range b.Parts {
		part := b.Parts[i]
275
		if part.Completed.Load() == part.Size {
Michael Yang's avatar
Michael Yang committed
276
277
			continue
		}
278

279
		g.Go(func() error {
Michael Yang's avatar
Michael Yang committed
280
			var err error
Michael Yang's avatar
Michael Yang committed
281
			for try := 0; try < maxRetries; try++ {
282
				w := io.NewOffsetWriter(file, part.StartsAt())
283
				err = b.downloadChunk(inner, directURL, w, part)
284
				switch {
285
286
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
287
					return err
288
289
290
				case errors.Is(err, errPartStalled):
					try--
					continue
291
				case err != nil:
Jeffrey Morgan's avatar
Jeffrey Morgan committed
292
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
293
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
Michael Yang's avatar
Michael Yang committed
294
					time.Sleep(sleep)
Michael Yang's avatar
Michael Yang committed
295
					continue
296
297
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
298
299
300
				}
			}

Michael Yang's avatar
Michael Yang committed
301
			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
Michael Yang's avatar
Michael Yang committed
302
		})
303
304
	}

Michael Yang's avatar
Michael Yang committed
305
	if err := g.Wait(); err != nil {
306
		return err
307
308
	}

309
310
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
311
		return err
Michael Yang's avatar
Michael Yang committed
312
313
	}

314
	for i := range b.Parts {
315
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
316
			return err
Michael Yang's avatar
Michael Yang committed
317
		}
318
319
	}

Michael Yang's avatar
Michael Yang committed
320
	if err := os.Rename(file.Name(), b.Name); err != nil {
321
		return err
Michael Yang's avatar
Michael Yang committed
322
323
	}

324
	return nil
325
326
}

327
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
328
329
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
330
331
332
333
334
335
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
		if err != nil {
			return err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := http.DefaultClient.Do(req)
336
337
338
339
		if err != nil {
			return err
		}
		defer resp.Body.Close()
340

341
		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
342
343
344
345
346
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}
347

348
		part.Completed.Add(n)
349
350
351
352
353
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
Michael Yang's avatar
Michael Yang committed
354
		return err
355
356
357
358
359
360
361
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		for {
			select {
			case <-ticker.C:
362
				if part.Completed.Load() >= part.Size {
363
364
365
					return nil
				}

366
367
368
369
370
				part.lastUpdatedMu.Lock()
				lastUpdated := part.lastUpdated
				part.lastUpdatedMu.Unlock()

				if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
371
372
					const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
					slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
373
					// reset last updated
374
					part.lastUpdatedMu.Lock()
375
					part.lastUpdated = time.Time{}
376
					part.lastUpdatedMu.Unlock()
377
378
379
380
381
382
383
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
Michael Yang's avatar
Michael Yang committed
384

385
	return g.Wait()
386
387
}

Michael Yang's avatar
Michael Yang committed
388
389
390
391
392
393
394
395
396
397
// newPart creates the part covering [offset, offset+size), numbered after the
// parts already in b.Parts. It persists the part's metadata with writePart
// first, so Prepare can restore the layout later, and only then appends it.
func (b *blobDownload) newPart(offset, size int64) error {
	part := &blobDownloadPart{
		blobDownload: b,
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
	}

	if err := b.writePart(part.Name(), part); err != nil {
		return err
	}

	b.Parts = append(b.Parts, part)
	return nil
}

398
399
400
401
402
403
404
405
406
407
408
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
409

Michael Yang's avatar
Michael Yang committed
410
	part.blobDownload = b
411
	return &part, nil
Michael Yang's avatar
Michael Yang committed
412
413
}

414
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
Michael Yang's avatar
Michael Yang committed
415
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
Michael Yang's avatar
Michael Yang committed
416
417
	if err != nil {
		return err
418
	}
Michael Yang's avatar
Michael Yang committed
419
	defer partFile.Close()
420

Michael Yang's avatar
Michael Yang committed
421
	return json.NewEncoder(partFile).Encode(part)
422
}
423

Michael Yang's avatar
Michael Yang committed
424
425
426
427
428
429
430
431
432
433
// acquire registers one more waiter on b; each call must be paired with a
// later release.
func (b *blobDownload) acquire() {
	_ = b.references.Add(1)
}

// release drops one waiter registered by acquire. When no waiters remain it
// cancels the download through b.CancelFunc.
//
// NOTE(review): b.CancelFunc is assigned inside run, which executes on the
// goroutine downloadBlob starts. If the last waiter releases before that
// assignment, CancelFunc is still nil and this call panics; the unsynchronized
// read is also a data race. Confirm whether this ordering can occur.
func (b *blobDownload) release() {
	remaining := b.references.Add(-1)
	if remaining != 0 {
		return
	}
	b.CancelFunc()
}

434
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
435
436
	b.acquire()
	defer b.release()
437
438
439
440

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
441
442
		case <-b.done:
			return b.err
443
		case <-ticker.C:
444
445
446
447
448
449
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})
450
451
452
453
454
455
456
457
458
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

type downloadOpts struct {
	mp      ModelPath
	digest  string
Michael Yang's avatar
Michael Yang committed
459
	regOpts *registryOptions
460
461
462
463
	fn      func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
464
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
465
466
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
467
		return false, err
468
469
470
471
472
473
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
474
		return false, err
475
476
	default:
		opts.fn(api.ProgressResponse{
Jeffrey Morgan's avatar
Jeffrey Morgan committed
477
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
478
479
480
481
482
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

483
		return true, nil
484
485
	}

Michael Yang's avatar
names  
Michael Yang committed
486
487
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
488
489
490
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
491
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
492
			blobDownloadManager.Delete(opts.digest)
493
			return false, err
494
495
		}

Michael Yang's avatar
Michael Yang committed
496
		//nolint:contextcheck
Michael Yang's avatar
names  
Michael Yang committed
497
		go download.Run(context.Background(), requestURL, opts.regOpts)
498
499
	}

500
	return false, download.Wait(ctx, opts.fn)
501
}