download.go 8.98 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
	"errors"
	"fmt"
	"io"
9
	"log/slog"
Jeffrey Morgan's avatar
Jeffrey Morgan committed
10
	"math"
11
	"net/http"
Michael Yang's avatar
Michael Yang committed
12
	"net/url"
13
	"os"
Michael Yang's avatar
Michael Yang committed
14
	"path/filepath"
15
	"strconv"
Michael Yang's avatar
Michael Yang committed
16
	"strings"
17
18
	"sync"
	"sync/atomic"
19
	"syscall"
20
	"time"
21

Michael Yang's avatar
Michael Yang committed
22
	"golang.org/x/sync/errgroup"
23
	"golang.org/x/sync/semaphore"
24
25

	"github.com/jmorganca/ollama/api"
26
	"github.com/jmorganca/ollama/format"
27
28
)

29
30
31
32
33
// maxRetries is the number of failed attempts Run allows per part before it
// gives up. Attempts that end because the part stalled are not counted.
const maxRetries = 6

var (
	// errMaxRetriesExceeded wraps the last error once a part has used up all
	// of its attempts.
	errMaxRetriesExceeded = errors.New("max retries exceeded")
	// errPartStalled is returned by downloadChunk's watchdog when a part has
	// received no data for more than 5 seconds.
	errPartStalled = errors.New("part stalled")

	// blobDownloadManager maps a blob digest to its in-flight *blobDownload so
	// that concurrent pulls of the same blob share a single download.
	blobDownloadManager sync.Map
)
35

36
37
38
type blobDownload struct {
	Name   string
	Digest string
39

40
41
	Total     int64
	Completed atomic.Int64
42

43
	Parts []*blobDownloadPart
44

45
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
46
47
48

	done       bool
	err        error
Michael Yang's avatar
Michael Yang committed
49
	references atomic.Int32
50
}
51

52
type blobDownloadPart struct {
53
54
55
56
57
	N           int
	Offset      int64
	Size        int64
	Completed   int64
	lastUpdated time.Time
Michael Yang's avatar
Michael Yang committed
58
59
60
61

	*blobDownload `json:"-"`
}

62
63
const (
	numDownloadParts          = 64
Michael Yang's avatar
Michael Yang committed
64
65
	minDownloadPartSize int64 = 100 * format.MegaByte
	maxDownloadPartSize int64 = 1000 * format.MegaByte
66
67
)

Michael Yang's avatar
Michael Yang committed
68
69
70
71
// Name returns the path of this part's progress file, "<blob name>-partial-<N>".
// The method shadows the embedded blobDownload.Name field, so the blob's own
// name has to be reached explicitly through p.blobDownload.
func (p *blobDownloadPart) Name() string {
	var sb strings.Builder
	sb.WriteString(p.blobDownload.Name)
	sb.WriteString("-partial-")
	sb.WriteString(strconv.Itoa(p.N))
	return sb.String()
}
73

74
75
76
77
78
79
80
81
// StartsAt returns the absolute blob offset of the first byte of this part
// that has not been downloaded yet.
func (p *blobDownloadPart) StartsAt() int64 {
	return p.Completed + p.Offset
}

// StopsAt returns the absolute blob offset just past the end of this part
// (exclusive).
func (p *blobDownloadPart) StopsAt() int64 {
	return p.Size + p.Offset
}

82
83
84
85
86
87
88
// Write implements io.Writer so the part can sit on the receiving side of an
// io.TeeReader. It stores nothing. It only adds len(b) to the owning
// download's Completed counter and stamps lastUpdated for the stall watchdog.
// It never returns an error.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	count := len(b)
	p.blobDownload.Completed.Add(int64(count))
	p.lastUpdated = time.Now()
	return count, nil
}

Michael Yang's avatar
Michael Yang committed
89
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
90
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
91
	if err != nil {
92
		return err
93
94
	}

Michael Yang's avatar
Michael Yang committed
95
	for _, partFilePath := range partFilePaths {
96
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
97
98
		if err != nil {
			return err
99
100
		}

101
102
103
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
104
	}
105

106
	if len(b.Parts) == 0 {
Michael Yang's avatar
Michael Yang committed
107
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
108
		if err != nil {
Michael Yang's avatar
Michael Yang committed
109
110
111
112
			return err
		}
		defer resp.Body.Close()

113
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
114

Michael Yang's avatar
Michael Yang committed
115
		size := b.Total / numDownloadParts
116
117
118
119
120
121
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
122

123
		var offset int64
124
125
126
127
128
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
129
			if err := b.newPart(offset, size); err != nil {
130
				return err
Michael Yang's avatar
Michael Yang committed
131
132
133
			}

			offset += size
134
135
136
		}
	}

137
	slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
138
139
140
	return nil
}

Michael Yang's avatar
Michael Yang committed
141
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
142
143
144
	defer blobDownloadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

Michael Yang's avatar
Michael Yang committed
145
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
146
	if err != nil {
Michael Yang's avatar
Michael Yang committed
147
148
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
149
	}
150
	defer file.Close()
151

Michael Yang's avatar
Michael Yang committed
152
	_ = file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
153

154
	g, inner := NewLimitGroup(ctx, numDownloadParts)
155
156
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
157
158
159
		if part.Completed == part.Size {
			continue
		}
160

Michael Yang's avatar
Michael Yang committed
161
		g.Go(func() error {
Michael Yang's avatar
Michael Yang committed
162
			var err error
Michael Yang's avatar
Michael Yang committed
163
			for try := 0; try < maxRetries; try++ {
164
				w := io.NewOffsetWriter(file, part.StartsAt())
Michael Yang's avatar
Michael Yang committed
165
				err = b.downloadChunk(inner, requestURL, w, part, opts)
166
				switch {
167
168
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
169
					return err
170
171
172
				case errors.Is(err, errPartStalled):
					try--
					continue
173
				case err != nil:
Jeffrey Morgan's avatar
Jeffrey Morgan committed
174
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
175
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
Michael Yang's avatar
Michael Yang committed
176
					time.Sleep(sleep)
Michael Yang's avatar
Michael Yang committed
177
					continue
178
179
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
180
181
182
				}
			}

Michael Yang's avatar
Michael Yang committed
183
			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
Michael Yang's avatar
Michael Yang committed
184
		})
185
186
	}

Michael Yang's avatar
Michael Yang committed
187
	if err := g.Wait(); err != nil {
Michael Yang's avatar
Michael Yang committed
188
189
		b.err = err
		return
190
191
	}

192
193
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
194
195
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
196
197
	}

198
	for i := range b.Parts {
199
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
200
201
			b.err = err
			return
Michael Yang's avatar
Michael Yang committed
202
		}
203
204
	}

Michael Yang's avatar
Michael Yang committed
205
	if err := os.Rename(file.Name(), b.Name); err != nil {
Michael Yang's avatar
Michael Yang committed
206
207
		b.err = err
		return
Michael Yang's avatar
Michael Yang committed
208
209
210
	}

	b.done = true
Michael Yang's avatar
Michael Yang committed
211
	return
212
213
}

Michael Yang's avatar
Michael Yang committed
214
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
215
216
217
218
219
220
221
222
223
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		headers := make(http.Header)
		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
224

225
226
227
228
229
230
		n, err := io.Copy(w, io.TeeReader(resp.Body, part))
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}
231

232
233
234
235
236
237
		part.Completed += n
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
Michael Yang's avatar
Michael Yang committed
238
		return err
239
240
241
242
243
244
245
246
247
248
249
250
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		for {
			select {
			case <-ticker.C:
				if part.Completed >= part.Size {
					return nil
				}

				if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
Michael Yang's avatar
Michael Yang committed
251
					slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
252
253
254
255
256
257
258
259
260
					// reset last updated
					part.lastUpdated = time.Time{}
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})
Michael Yang's avatar
Michael Yang committed
261

262
	return g.Wait()
263
264
}

Michael Yang's avatar
Michael Yang committed
265
266
267
268
269
270
271
272
273
274
// newPart plans a part covering size bytes starting at offset and numbers it
// after the parts already in b.Parts. It writes the part's initial progress
// file and appends the part only if that write succeeds.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
		blobDownload: b,
	}
	if err := b.writePart(p.Name(), p); err != nil {
		return err
	}

	b.Parts = append(b.Parts, p)
	return nil
}

275
276
277
278
279
280
281
282
283
284
285
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
286

Michael Yang's avatar
Michael Yang committed
287
	part.blobDownload = b
288
	return &part, nil
Michael Yang's avatar
Michael Yang committed
289
290
}

291
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
Michael Yang's avatar
Michael Yang committed
292
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
Michael Yang's avatar
Michael Yang committed
293
294
	if err != nil {
		return err
295
	}
Michael Yang's avatar
Michael Yang committed
296
	defer partFile.Close()
297

Michael Yang's avatar
Michael Yang committed
298
	return json.NewEncoder(partFile).Encode(part)
299
}
300

Michael Yang's avatar
Michael Yang committed
301
302
303
304
305
306
307
308
309
310
// acquire records one more caller waiting on the download.
func (b *blobDownload) acquire() {
	b.references.Add(1)
}

// release drops a reference taken by acquire. When the last reference is
// dropped, the download is canceled through the cancel function Run installs.
// If Run has not installed it yet, for example because a waiter gave up
// before the Run goroutine started, the download is left alone. Calling a nil
// CancelFunc would panic.
func (b *blobDownload) release() {
	// NOTE(review): CancelFunc is assigned by Run on another goroutine without
	// synchronization; this check avoids the nil call but is not race-free.
	if b.references.Add(-1) == 0 && b.CancelFunc != nil {
		b.CancelFunc()
	}
}

311
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
312
313
	b.acquire()
	defer b.release()
314
315
316
317
318

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
		case <-ticker.C:
319
320
321
322
323
324
325
326
327
328
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})

			if b.done || b.err != nil {
				return b.err
			}
329
330
331
332
333
334
335
336
337
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

type downloadOpts struct {
	mp      ModelPath
	digest  string
Michael Yang's avatar
Michael Yang committed
338
	regOpts *registryOptions
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
	fn      func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
Jeffrey Morgan's avatar
Jeffrey Morgan committed
356
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
357
358
359
360
361
362
363
364
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

Michael Yang's avatar
names  
Michael Yang committed
365
366
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
367
368
369
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
370
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
371
			blobDownloadManager.Delete(opts.digest)
372
373
374
			return err
		}

Michael Yang's avatar
Michael Yang committed
375
		// nolint: contextcheck
Michael Yang's avatar
names  
Michael Yang committed
376
		go download.Run(context.Background(), requestURL, opts.regOpts)
377
378
	}

Michael Yang's avatar
names  
Michael Yang committed
379
	return download.Wait(ctx, opts.fn)
380
}
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418

// LimitGroup is an errgroup.Group whose Go method first acquires weight units
// from a semaphore, bounding how many functions run at once. Initially weight
// equals the semaphore's capacity, so only one function runs at a time until
// SetLimit lowers it.
type LimitGroup struct {
	*errgroup.Group
	context.Context
	Semaphore *semaphore.Weighted

	// weight is what each Go call acquires.
	weight int64
	// max_weight is the semaphore's capacity.
	max_weight int64
}

// NewLimitGroup returns a LimitGroup whose semaphore has capacity n, together
// with the context derived by errgroup.WithContext. That context is canceled
// when a function in the group returns an error (see the errgroup docs).
func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
	group, groupCtx := errgroup.WithContext(ctx)
	lg := &LimitGroup{
		Group:     group,
		Context:   groupCtx,
		Semaphore: semaphore.NewWeighted(n),
	}
	lg.weight, lg.max_weight = n, n
	return lg, groupCtx
}

// Go runs fn in the group once the current weight can be acquired from the
// semaphore, blocking the caller until then. The weight is released when fn
// returns. If the group's context is done, fn is not started and any
// capacity that was acquired is handed back.
func (g *LimitGroup) Go(fn func() error) {
	weight := g.weight
	if err := g.Semaphore.Acquire(g.Context, weight); err != nil {
		// the context ended before capacity became available; nothing was acquired
		return
	}
	if g.Context.Err() != nil {
		// Acquire may succeed even when the context is already done; return
		// the capacity instead of leaking it.
		g.Semaphore.Release(weight)
		return
	}

	g.Group.Go(func() error {
		defer g.Semaphore.Release(weight)
		return fn()
	})
}

// SetLimit changes the weight that later Go calls acquire so that about n
// functions can run concurrently. Because the weight comes from integer
// division, somewhat more than n may run. n is capped by the group's
// capacity, and non-positive values are ignored.
func (g *LimitGroup) SetLimit(n int64) {
	if n <= 0 {
		return
	}

	weight := g.max_weight / n
	if weight < 1 {
		// n exceeds the capacity. A zero weight would let Acquire succeed
		// without taking anything, removing the limit entirely.
		weight = 1
	}
	g.weight = weight
}