download.go 7.42 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
9
10
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
Michael Yang's avatar
Michael Yang committed
11
	"net/url"
12
	"os"
Michael Yang's avatar
Michael Yang committed
13
	"path/filepath"
14
	"strconv"
Michael Yang's avatar
Michael Yang committed
15
	"strings"
16
17
	"sync"
	"sync/atomic"
18
	"syscall"
19
	"time"
20

Michael Yang's avatar
Michael Yang committed
21
	"golang.org/x/sync/errgroup"
22
23

	"github.com/jmorganca/ollama/api"
24
	"github.com/jmorganca/ollama/format"
25
26
)

27
var (
	// blobDownloadManager maps a blob digest (string) to the *blobDownload
	// currently fetching it, so concurrent requests for the same blob share a
	// single download. downloadBlob adds entries; an entry is removed when its
	// download ends or fails to prepare.
	blobDownloadManager sync.Map
)
28

29
30
31
type blobDownload struct {
	Name   string
	Digest string
32

33
34
	Total     int64
	Completed atomic.Int64
35

36
	Parts []*blobDownloadPart
37

38
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
39
40
41

	done       bool
	err        error
Michael Yang's avatar
Michael Yang committed
42
	references atomic.Int32
43
}
44

45
type blobDownloadPart struct {
Michael Yang's avatar
Michael Yang committed
46
	N         int
47
48
49
	Offset    int64
	Size      int64
	Completed int64
Michael Yang's avatar
Michael Yang committed
50
51
52
53

	*blobDownload `json:"-"`
}

54
55
56
57
58
59
// Part sizing used by Prepare: a blob is split into a target of
// numDownloadParts parts, each clamped to [minDownloadPartSize,
// maxDownloadPartSize] bytes (decimal: 32 MB and 256 MB). numDownloadParts
// also caps how many parts run downloads concurrently.
const (
	numDownloadParts = 64

	minDownloadPartSize int64 = 32_000_000
	maxDownloadPartSize int64 = 256_000_000
)

Michael Yang's avatar
Michael Yang committed
60
61
62
63
// Name returns the path of the part's metadata file: the blob path followed
// by "-partial-" and the part index.
func (p *blobDownloadPart) Name() string {
	var sb strings.Builder
	sb.WriteString(p.blobDownload.Name)
	sb.WriteString("-partial-")
	sb.WriteString(strconv.Itoa(p.N))
	return sb.String()
}
65

66
67
68
69
70
71
72
73
// StartsAt returns the absolute blob offset of the first byte of this part
// that has not been downloaded yet.
func (p *blobDownloadPart) StartsAt() int64 {
	resumeAt := p.Offset + p.Completed
	return resumeAt
}

// StopsAt returns the absolute blob offset just past the last byte of this
// part (an exclusive end).
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Offset + p.Size
	return end
}

74
75
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
76
	if err != nil {
77
		return err
78
79
	}

Michael Yang's avatar
Michael Yang committed
80
	for _, partFilePath := range partFilePaths {
81
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
82
83
		if err != nil {
			return err
84
85
		}

86
87
88
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
89
	}
90

91
92
	if len(b.Parts) == 0 {
		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
93
		if err != nil {
Michael Yang's avatar
Michael Yang committed
94
95
96
97
			return err
		}
		defer resp.Body.Close()

Michael Yang's avatar
Michael Yang committed
98
99
100
101
102
		if resp.StatusCode >= http.StatusBadRequest {
			body, _ := io.ReadAll(resp.Body)
			return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
		}

103
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
104

105
106
107
108
109
110
111
		var size = b.Total / numDownloadParts
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
112

113
		var offset int64
114
115
116
117
118
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
119
			if err := b.newPart(offset, size); err != nil {
120
				return err
Michael Yang's avatar
Michael Yang committed
121
122
123
			}

			offset += size
124
125
126
		}
	}

127
	log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
128
129
130
	return nil
}

Michael Yang's avatar
Michael Yang committed
131
132
133
134
135
// Run executes the download and stores its outcome in b.err. downloadBlob
// starts it in its own goroutine; waiters observe the result through Wait.
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
	runErr := b.run(ctx, requestURL, opts)
	b.err = runErr
}

func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
136
137
138
139
	defer blobDownloadManager.Delete(b.Digest)

	ctx, b.CancelFunc = context.WithCancel(ctx)

140
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
141
142
	if err != nil {
		return err
Michael Yang's avatar
Michael Yang committed
143
	}
144
	defer file.Close()
145

146
	file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
147

Michael Yang's avatar
Michael Yang committed
148
	g, inner := errgroup.WithContext(ctx)
149
	g.SetLimit(numDownloadParts)
150
151
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
152
153
154
		if part.Completed == part.Size {
			continue
		}
155

Michael Yang's avatar
Michael Yang committed
156
157
158
		i := i
		g.Go(func() error {
			for try := 0; try < maxRetries; try++ {
159
				w := io.NewOffsetWriter(file, part.StartsAt())
Michael Yang's avatar
Michael Yang committed
160
				err := b.downloadChunk(inner, requestURL, w, part, opts)
161
				switch {
162
163
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					// return immediately if the context is canceled or the device is out of space
164
165
166
					return err
				case err != nil:
					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
Michael Yang's avatar
Michael Yang committed
167
					continue
168
169
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
170
171
172
173
174
				}
			}

			return errors.New("max retries exceeded")
		})
175
176
	}

Michael Yang's avatar
Michael Yang committed
177
178
	if err := g.Wait(); err != nil {
		return err
179
180
	}

181
182
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
183
184
185
		return err
	}

186
	for i := range b.Parts {
187
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
188
189
			return err
		}
190
191
	}

Michael Yang's avatar
Michael Yang committed
192
193
194
195
196
197
	if err := os.Rename(file.Name(), b.Name); err != nil {
		return err
	}

	b.done = true
	return nil
198
199
}

200
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
Michael Yang's avatar
Michael Yang committed
201
	headers := make(http.Header)
202
	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
203
	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
Michael Yang's avatar
Michael Yang committed
204
205
206
207
	if err != nil {
		return err
	}
	defer resp.Body.Close()
208

209
	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
Michael Yang's avatar
Michael Yang committed
210
	if err != nil && !errors.Is(err, context.Canceled) {
211
212
		// rollback progress
		b.Completed.Add(-n)
Michael Yang's avatar
Michael Yang committed
213
214
		return err
	}
215

Michael Yang's avatar
Michael Yang committed
216
	part.Completed += n
Michael Yang's avatar
Michael Yang committed
217
	if err := b.writePart(part.Name(), part); err != nil {
Michael Yang's avatar
Michael Yang committed
218
219
220
221
222
		return err
	}

	// return nil or context.Canceled
	return err
223
224
}

Michael Yang's avatar
Michael Yang committed
225
226
227
228
229
230
231
232
233
234
// newPart persists the metadata for a new part covering [offset, offset+size)
// and, only if that succeeds, appends it to b.Parts. The part is numbered by
// its position in b.Parts.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
		blobDownload: b,
	}

	err := b.writePart(p.Name(), p)
	if err == nil {
		b.Parts = append(b.Parts, p)
	}
	return err
}

235
236
237
238
239
240
241
242
243
244
245
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
246

Michael Yang's avatar
Michael Yang committed
247
	part.blobDownload = b
248
	return &part, nil
Michael Yang's avatar
Michael Yang committed
249
250
}

251
252
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
Michael Yang's avatar
Michael Yang committed
253
254
	if err != nil {
		return err
255
	}
Michael Yang's avatar
Michael Yang committed
256
	defer partFile.Close()
257

Michael Yang's avatar
Michael Yang committed
258
	return json.NewEncoder(partFile).Encode(part)
259
}
260
261
262
263
264
265
266

// Write implements io.Writer so b can be the sink of an io.TeeReader. It
// discards p, adds len(p) to b.Completed, and never fails.
func (b *blobDownload) Write(p []byte) (int, error) {
	count := len(p)
	b.Completed.Add(int64(count))
	return count, nil
}

Michael Yang's avatar
Michael Yang committed
267
268
269
270
271
272
273
274
275
276
// acquire registers one more waiter on b.
func (b *blobDownload) acquire() {
	b.references.Add(1)
}

// release unregisters a waiter. When the last waiter leaves, the download is
// cancelled so no work continues for callers that have given up.
//
// CancelFunc is only assigned once run starts in its own goroutine, so it can
// still be nil here. Examples: a waiter whose context was already cancelled,
// or a download whose Prepare failed and will never run. Calling it unguarded
// would panic.
// NOTE(review): run writes CancelFunc without synchronization, so this read
// can still race with it; consider assigning it before the goroutine starts.
func (b *blobDownload) release() {
	if b.references.Add(-1) == 0 && b.CancelFunc != nil {
		b.CancelFunc()
	}
}

277
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
278
279
	b.acquire()
	defer b.release()
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", b.Digest),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

Michael Yang's avatar
Michael Yang committed
296
297
		if b.done || b.err != nil {
			return b.err
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
		}
	}
}

// downloadOpts bundles the inputs to downloadBlob.
type downloadOpts struct {
	// mp supplies the registry base URL and the repository the blob is fetched from.
	mp      ModelPath
	// digest identifies the blob; it is also used to locate it on disk.
	digest  string
	// regOpts is passed through to the registry requests.
	regOpts *RegistryOptions
	// fn receives progress updates.
	fn      func(api.ProgressResponse)
}

const (
	// maxRetries bounds the attempts made for a failing operation. Within this
	// file it is the number of tries run makes for each part before giving up.
	maxRetries = 3
)

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", opts.digest),
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

Michael Yang's avatar
names  
Michael Yang committed
334
335
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
336
337
338
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
339
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
340
			blobDownloadManager.Delete(opts.digest)
341
342
343
			return err
		}

Michael Yang's avatar
names  
Michael Yang committed
344
		go download.Run(context.Background(), requestURL, opts.regOpts)
345
346
	}

Michael Yang's avatar
names  
Michael Yang committed
347
	return download.Wait(ctx, opts.fn)
348
}