download.go 9.67 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
	"errors"
	"fmt"
	"io"
9
	"log/slog"
Jeffrey Morgan's avatar
Jeffrey Morgan committed
10
	"math"
11
	"net/http"
Michael Yang's avatar
Michael Yang committed
12
	"net/url"
13
	"os"
Michael Yang's avatar
Michael Yang committed
14
	"path/filepath"
15
	"strconv"
Michael Yang's avatar
Michael Yang committed
16
	"strings"
17
18
	"sync"
	"sync/atomic"
19
	"syscall"
20
	"time"
21

Michael Yang's avatar
Michael Yang committed
22
	"golang.org/x/sync/errgroup"
23
	"golang.org/x/sync/semaphore"
24
25

	"github.com/jmorganca/ollama/api"
26
	"github.com/jmorganca/ollama/format"
27
28
)

29
30
31
32
33
// maxRetries bounds how many failed attempts (other than stalls) a single part
// may make before the download of that part gives up.
const maxRetries = 6

var (
	// errMaxRetriesExceeded is wrapped around the last error once a part has
	// used up maxRetries attempts.
	errMaxRetriesExceeded = errors.New("max retries exceeded")

	// errPartStalled reports that a part stopped receiving data; Run retries
	// such a part without counting it as an attempt.
	errPartStalled = errors.New("part stalled")
)

34
// blobDownloadManager maps a blob digest to its in-flight *blobDownload so that
// concurrent downloadBlob calls for the same digest share one download. Run
// removes the entry when it returns.
var blobDownloadManager sync.Map
35

36
37
38
type blobDownload struct {
	Name   string
	Digest string
39

40
41
	Total     int64
	Completed atomic.Int64
42

43
	Parts []*blobDownloadPart
44

45
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
46
47
48

	done       bool
	err        error
Michael Yang's avatar
Michael Yang committed
49
	references atomic.Int32
50
}
51

52
type blobDownloadPart struct {
53
54
55
56
57
	N           int
	Offset      int64
	Size        int64
	Completed   int64
	lastUpdated time.Time
Michael Yang's avatar
Michael Yang committed
58
59
60
61

	*blobDownload `json:"-"`
}

62
63
const (
	numDownloadParts          = 64
Michael Yang's avatar
Michael Yang committed
64
65
	minDownloadPartSize int64 = 100 * format.MegaByte
	maxDownloadPartSize int64 = 1000 * format.MegaByte
66
67
)

Michael Yang's avatar
Michael Yang committed
68
69
70
71
// Name returns the path of this part's metadata file,
// "<blob name>-partial-<N>".
func (p *blobDownloadPart) Name() string {
	return p.blobDownload.Name + "-partial-" + strconv.Itoa(p.N)
}
73

74
75
76
77
78
79
80
81
// StartsAt returns the absolute blob offset where downloading of this part
// should continue: the part's start plus the bytes already completed.
func (p *blobDownloadPart) StartsAt() int64 {
	start := p.Offset
	start += p.Completed
	return start
}

// StopsAt returns the absolute blob offset one past the last byte of this
// part (exclusive end).
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Size
	end += p.Offset
	return end
}

82
83
84
85
86
87
88
// Write implements io.Writer for progress accounting only. It discards the
// data, adds len(b) to the owning blob's Completed counter, and stamps
// lastUpdated with the current time. It never returns an error.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	received := len(b)
	p.blobDownload.Completed.Add(int64(received))
	p.lastUpdated = time.Now()
	return received, nil
}

Michael Yang's avatar
Michael Yang committed
89
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
90
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
91
	if err != nil {
92
		return err
93
94
	}

Michael Yang's avatar
Michael Yang committed
95
	for _, partFilePath := range partFilePaths {
96
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
97
98
		if err != nil {
			return err
99
100
		}

101
102
103
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
104
	}
105

106
	if len(b.Parts) == 0 {
Michael Yang's avatar
Michael Yang committed
107
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
108
		if err != nil {
Michael Yang's avatar
Michael Yang committed
109
110
111
112
			return err
		}
		defer resp.Body.Close()

113
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
114

Michael Yang's avatar
Michael Yang committed
115
		size := b.Total / numDownloadParts
116
117
118
119
120
121
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
122

123
		var offset int64
124
125
126
127
128
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
129
			if err := b.newPart(offset, size); err != nil {
130
				return err
Michael Yang's avatar
Michael Yang committed
131
132
133
			}

			offset += size
134
135
136
		}
	}

137
	slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
138
139
140
	return nil
}

Michael Yang's avatar
Michael Yang committed
141
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
142
143
144
	defer blobDownloadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

Michael Yang's avatar
Michael Yang committed
145
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
146
	if err != nil {
Michael Yang's avatar
Michael Yang committed
147
148
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
149
	}
150
	defer file.Close()
151

Michael Yang's avatar
Michael Yang committed
152
	_ = file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
153

154
	g, inner := NewLimitGroup(ctx, numDownloadParts)
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184

	go func() {
		ticker := time.NewTicker(time.Second)
		var n int64 = 1
		var maxDelta float64
		var buckets []int64
		for {
			select {
			case <-ticker.C:
				buckets = append(buckets, b.Completed.Load())
				if len(buckets) < 2 {
					continue
				} else if len(buckets) > 10 {
					buckets = buckets[1:]
				}

				delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
				slog.Debug(fmt.Sprintf("delta: %s/s max_delta: %s/s", format.HumanBytes(int64(delta)), format.HumanBytes(int64(maxDelta))))
				if delta > maxDelta*1.5 {
					maxDelta = delta
					g.SetLimit(n)
					n++
				}

			case <-ctx.Done():
				return
			}
		}
	}()

185
186
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
187
188
189
		if part.Completed == part.Size {
			continue
		}
190

Michael Yang's avatar
Michael Yang committed
191
		g.Go(func() error {
Michael Yang's avatar
Michael Yang committed
192
			var err error
Michael Yang's avatar
Michael Yang committed
193
			for try := 0; try < maxRetries; try++ {
194
				w := io.NewOffsetWriter(file, part.StartsAt())
Michael Yang's avatar
Michael Yang committed
195
				err = b.downloadChunk(inner, requestURL, w, part, opts)
196
				switch {
197
198
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
199
					return err
200
201
202
				case errors.Is(err, errPartStalled):
					try--
					continue
203
				case err != nil:
Jeffrey Morgan's avatar
Jeffrey Morgan committed
204
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
205
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
Michael Yang's avatar
Michael Yang committed
206
					time.Sleep(sleep)
Michael Yang's avatar
Michael Yang committed
207
					continue
208
209
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
210
211
212
				}
			}

Michael Yang's avatar
Michael Yang committed
213
			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
Michael Yang's avatar
Michael Yang committed
214
		})
215
216
	}

Michael Yang's avatar
Michael Yang committed
217
	if err := g.Wait(); err != nil {
Michael Yang's avatar
Michael Yang committed
218
219
		b.err = err
		return
220
221
	}

222
223
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
224
225
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
226
227
	}

228
	for i := range b.Parts {
229
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
230
231
			b.err = err
			return
Michael Yang's avatar
Michael Yang committed
232
		}
233
234
	}

Michael Yang's avatar
Michael Yang committed
235
	if err := os.Rename(file.Name(), b.Name); err != nil {
Michael Yang's avatar
Michael Yang committed
236
237
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
238
239
240
	}

	b.done = true
Michael Yang's avatar
Michael Yang committed
241
	return
242
243
}

Michael Yang's avatar
Michael Yang committed
244
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
245
246
247
248
249
250
251
252
253
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		headers := make(http.Header)
		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
254

255
256
257
258
259
260
		n, err := io.Copy(w, io.TeeReader(resp.Body, part))
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}
261

262
263
264
265
266
267
		part.Completed += n
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
Michael Yang's avatar
Michael Yang committed
268
		return err
269
270
271
272
273
274
275
276
277
278
279
280
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		for {
			select {
			case <-ticker.C:
				if part.Completed >= part.Size {
					return nil
				}

				if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
Michael Yang's avatar
Michael Yang committed
281
					slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
282
283
284
285
286
287
288
289
290
					// reset last updated
					part.lastUpdated = time.Time{}
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
Michael Yang's avatar
Michael Yang committed
291

292
	return g.Wait()
293
294
}

Michael Yang's avatar
Michael Yang committed
295
296
297
298
299
300
301
302
303
304
// newPart creates part number len(b.Parts) covering size bytes from offset,
// writes its metadata file, and only then appends it to b.Parts. If writing
// the metadata fails, b.Parts is left unchanged and the error is returned.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
		blobDownload: b,
	}

	err := b.writePart(p.Name(), p)
	if err != nil {
		return err
	}

	b.Parts = append(b.Parts, p)
	return nil
}

305
306
307
308
309
310
311
312
313
314
315
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
316

Michael Yang's avatar
Michael Yang committed
317
	part.blobDownload = b
318
	return &part, nil
Michael Yang's avatar
Michael Yang committed
319
320
}

321
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
Michael Yang's avatar
Michael Yang committed
322
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
Michael Yang's avatar
Michael Yang committed
323
324
	if err != nil {
		return err
325
	}
Michael Yang's avatar
Michael Yang committed
326
	defer partFile.Close()
327

Michael Yang's avatar
Michael Yang committed
328
	return json.NewEncoder(partFile).Encode(part)
329
}
330

Michael Yang's avatar
Michael Yang committed
331
332
333
334
335
336
337
338
339
340
// acquire registers one more waiter on b; every call is paired with release.
func (b *blobDownload) acquire() {
	_ = b.references.Add(1)
}

// release drops one waiter registered by acquire. When the count reaches
// zero, the download is cancelled through CancelFunc.
//
// CancelFunc is installed by Run, which downloadBlob starts in a separate
// goroutine before calling Wait. A waiter whose context is already done can
// get here before Run has set it, and calling a nil func would panic, so a
// nil CancelFunc is skipped.
// NOTE(review): CancelFunc is still written and read without synchronization.
func (b *blobDownload) release() {
	if b.references.Add(-1) == 0 && b.CancelFunc != nil {
		b.CancelFunc()
	}
}

341
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
342
343
	b.acquire()
	defer b.release()
344
345
346
347
348

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
		case <-ticker.C:
349
350
351
352
353
354
355
356
357
358
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})

			if b.done || b.err != nil {
				return b.err
			}
359
360
361
362
363
364
365
366
367
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

type downloadOpts struct {
	mp      ModelPath
	digest  string
Michael Yang's avatar
Michael Yang committed
368
	regOpts *registryOptions
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
	fn      func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
Jeffrey Morgan's avatar
Jeffrey Morgan committed
386
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
387
388
389
390
391
392
393
394
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

Michael Yang's avatar
names  
Michael Yang committed
395
396
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
397
398
399
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
400
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
401
			blobDownloadManager.Delete(opts.digest)
402
403
404
			return err
		}

Michael Yang's avatar
Michael Yang committed
405
		// nolint: contextcheck
Michael Yang's avatar
names  
Michael Yang committed
406
		go download.Run(context.Background(), requestURL, opts.regOpts)
407
408
	}

Michael Yang's avatar
names  
Michael Yang committed
409
	return download.Wait(ctx, opts.fn)
410
}
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445

// LimitGroup is an errgroup.Group whose concurrency can be changed while it
// runs. Every goroutine started with Go holds `weight` units of a semaphore of
// size maxWeight, so SetLimit(n) allows roughly n goroutines at once.
type LimitGroup struct {
	*errgroup.Group
	context.Context
	Semaphore *semaphore.Weighted

	// weight is read by Go and updated by SetLimit from different goroutines,
	// hence atomic.
	weight    atomic.Int64
	maxWeight int64
}

// NewLimitGroup returns a LimitGroup backed by a semaphore of size n, together
// with the group's derived context. That context is cancelled when a goroutine
// returns an error or Wait returns. Each goroutine initially takes the full
// weight n, so goroutines run one at a time until SetLimit raises the limit.
func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
	g, ctx := errgroup.WithContext(ctx)
	lg := &LimitGroup{
		Group:     g,
		Context:   ctx,
		Semaphore: semaphore.NewWeighted(n),
		maxWeight: n,
	}
	lg.weight.Store(n)
	return lg, ctx
}

// Go blocks until the current weight can be acquired from the semaphore, then
// runs fn in the group and releases that weight when fn returns. If the
// group's context is done, fn is not started and nothing is reported.
func (g *LimitGroup) Go(fn func() error) {
	weight := g.weight.Load()
	if err := g.Semaphore.Acquire(g.Context, weight); err != nil {
		return
	}
	// Acquire may succeed even though the context is already done; hand the
	// weight back so the semaphore stays balanced.
	if g.Context.Err() != nil {
		g.Semaphore.Release(weight)
		return
	}

	g.Group.Go(func() error {
		defer g.Semaphore.Release(weight)
		return fn()
	})
}

// SetLimit makes goroutines started after this call take maxWeight/n units
// each, allowing about n of them to run concurrently; goroutines already
// running keep the weight they acquired. n <= 0 is ignored. The weight is
// never allowed to drop below 1, because a zero weight would not limit
// concurrency at all.
func (g *LimitGroup) SetLimit(n int64) {
	if n <= 0 {
		return
	}

	slog.Debug(fmt.Sprintf("setting limit to %d", n))
	weight := g.maxWeight / n
	if weight < 1 {
		weight = 1
	}
	g.weight.Store(weight)
}