download.go 7.12 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
9
10
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
Michael Yang's avatar
Michael Yang committed
11
	"net/url"
12
	"os"
Michael Yang's avatar
Michael Yang committed
13
	"path/filepath"
14
	"strconv"
Michael Yang's avatar
Michael Yang committed
15
	"strings"
16
17
18
	"sync"
	"sync/atomic"
	"time"
19

Michael Yang's avatar
Michael Yang committed
20
	"golang.org/x/sync/errgroup"
21
22

	"github.com/jmorganca/ollama/api"
23
	"github.com/jmorganca/ollama/format"
24
25
)

26
var (
	// blobDownloadManager maps a blob digest to the *blobDownload currently
	// fetching it, so concurrent requests for one blob share a single download.
	blobDownloadManager sync.Map
)
27

28
29
30
type blobDownload struct {
	Name   string
	Digest string
31

32
33
	Total     int64
	Completed atomic.Int64
Michael Yang's avatar
Michael Yang committed
34
	done      bool
35

36
	Parts []*blobDownloadPart
37

38
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
39
	references atomic.Int32
40
}
41

42
type blobDownloadPart struct {
Michael Yang's avatar
Michael Yang committed
43
	N         int
44
45
46
	Offset    int64
	Size      int64
	Completed int64
Michael Yang's avatar
Michael Yang committed
47
48
49
50

	*blobDownload `json:"-"`
}

51
52
53
54
55
56
const (
	// numDownloadParts is the number of parts a new download aims to be split
	// into; it also caps how many parts Run fetches concurrently.
	numDownloadParts = 64

	// Part sizes are clamped to this range regardless of the blob size.
	minDownloadPartSize int64 = 32_000_000
	maxDownloadPartSize int64 = 256_000_000
)

Michael Yang's avatar
Michael Yang committed
57
58
59
60
// Name returns the path of the file recording this part's progress: the
// owning blob's path followed by "-partial-<N>".
func (p *blobDownloadPart) Name() string {
	return p.blobDownload.Name + "-partial-" + strconv.Itoa(p.N)
}
62

63
64
65
66
67
68
69
70
// StartsAt returns the absolute blob offset of the first byte of this part
// that has not been written yet.
func (p *blobDownloadPart) StartsAt() int64 {
	next := p.Completed + p.Offset
	return next
}

// StopsAt returns the absolute blob offset just past the end of this part
// (an exclusive bound).
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Size + p.Offset
	return end
}

71
72
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
73
	if err != nil {
74
		return err
75
76
	}

Michael Yang's avatar
Michael Yang committed
77
	for _, partFilePath := range partFilePaths {
78
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
79
80
		if err != nil {
			return err
81
82
		}

83
84
85
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
86
	}
87

88
89
	if len(b.Parts) == 0 {
		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
90
		if err != nil {
Michael Yang's avatar
Michael Yang committed
91
92
93
94
			return err
		}
		defer resp.Body.Close()

Michael Yang's avatar
Michael Yang committed
95
96
97
98
99
		if resp.StatusCode >= http.StatusBadRequest {
			body, _ := io.ReadAll(resp.Body)
			return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
		}

100
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
101

102
103
104
105
106
107
108
		var size = b.Total / numDownloadParts
		switch {
		case size < minDownloadPartSize:
			size = minDownloadPartSize
		case size > maxDownloadPartSize:
			size = maxDownloadPartSize
		}
Michael Yang's avatar
Michael Yang committed
109

110
		var offset int64
111
112
113
114
115
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
116
			if err := b.newPart(offset, size); err != nil {
117
				return err
Michael Yang's avatar
Michael Yang committed
118
119
120
			}

			offset += size
121
122
123
		}
	}

124
	log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
125
126
127
128
129
130
131
132
	return nil
}

func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
	defer blobDownloadManager.Delete(b.Digest)

	ctx, b.CancelFunc = context.WithCancel(ctx)

133
	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
134
135
	if err != nil {
		return err
Michael Yang's avatar
Michael Yang committed
136
	}
137
	defer file.Close()
138

139
	file.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
140

Michael Yang's avatar
Michael Yang committed
141
	g, _ := errgroup.WithContext(ctx)
142
	g.SetLimit(numDownloadParts)
143
144
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
145
146
147
		if part.Completed == part.Size {
			continue
		}
148

Michael Yang's avatar
Michael Yang committed
149
150
151
		i := i
		g.Go(func() error {
			for try := 0; try < maxRetries; try++ {
152
153
				w := io.NewOffsetWriter(file, part.StartsAt())
				err := b.downloadChunk(ctx, requestURL, w, part, opts)
154
155
156
157
158
				switch {
				case errors.Is(err, context.Canceled):
					return err
				case err != nil:
					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
Michael Yang's avatar
Michael Yang committed
159
					continue
160
161
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
162
163
164
165
166
				}
			}

			return errors.New("max retries exceeded")
		})
167
168
	}

Michael Yang's avatar
Michael Yang committed
169
170
	if err := g.Wait(); err != nil {
		return err
171
172
	}

173
174
	// explicitly close the file so we can rename it
	if err := file.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
175
176
177
		return err
	}

178
	for i := range b.Parts {
179
		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
180
181
			return err
		}
182
183
	}

Michael Yang's avatar
Michael Yang committed
184
185
186
187
188
189
	if err := os.Rename(file.Name(), b.Name); err != nil {
		return err
	}

	b.done = true
	return nil
190
191
}

192
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
Michael Yang's avatar
Michael Yang committed
193
	headers := make(http.Header)
194
	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
195
	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
Michael Yang's avatar
Michael Yang committed
196
197
198
199
	if err != nil {
		return err
	}
	defer resp.Body.Close()
200

201
	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
Michael Yang's avatar
Michael Yang committed
202
	if err != nil && !errors.Is(err, context.Canceled) {
203
204
		// rollback progress
		b.Completed.Add(-n)
Michael Yang's avatar
Michael Yang committed
205
206
		return err
	}
207

Michael Yang's avatar
Michael Yang committed
208
	part.Completed += n
Michael Yang's avatar
Michael Yang committed
209
	if err := b.writePart(part.Name(), part); err != nil {
Michael Yang's avatar
Michael Yang committed
210
211
212
213
214
		return err
	}

	// return nil or context.Canceled
	return err
215
216
}

Michael Yang's avatar
Michael Yang committed
217
218
219
220
221
222
223
224
225
226
// newPart creates the part covering [offset, offset+size), persists its
// progress file, and appends it to b.Parts only once that write succeeds.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
		blobDownload: b,
	}

	if err := b.writePart(p.Name(), p); err != nil {
		return err
	}

	b.Parts = append(b.Parts, p)
	return nil
}

227
228
229
230
231
232
233
234
235
236
237
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
238

Michael Yang's avatar
Michael Yang committed
239
	part.blobDownload = b
240
	return &part, nil
Michael Yang's avatar
Michael Yang committed
241
242
}

243
244
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
Michael Yang's avatar
Michael Yang committed
245
246
	if err != nil {
		return err
247
	}
Michael Yang's avatar
Michael Yang committed
248
	defer partFile.Close()
249

Michael Yang's avatar
Michael Yang committed
250
	return json.NewEncoder(partFile).Encode(part)
251
}
252
253
254
255
256
257
258

// Write implements io.Writer for progress accounting only: it adds len(p) to
// b.Completed, discards the bytes, and never fails.
func (b *blobDownload) Write(p []byte) (int, error) {
	count := len(p)
	b.Completed.Add(int64(count))
	return count, nil
}

Michael Yang's avatar
Michael Yang committed
259
260
261
262
263
264
265
266
267
268
// acquire registers one more waiter on the download.
func (b *blobDownload) acquire() {
	_ = b.references.Add(1)
}

// release drops one waiter. When the last waiter goes away the download is
// canceled, provided Run has already installed its CancelFunc.
//
// A waiter can detach before Run starts. This happens when Wait is entered
// during Prepare, or just before the Run goroutine is scheduled. Calling a
// nil CancelFunc in that window would panic.
// NOTE(review): CancelFunc is still written by Run and read here without
// synchronization; this guard only removes the nil-call panic.
func (b *blobDownload) release() {
	if b.references.Add(-1) == 0 && b.CancelFunc != nil {
		b.CancelFunc()
	}
}

269
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
270
271
	b.acquire()
	defer b.release()
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", b.Digest),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

Michael Yang's avatar
Michael Yang committed
288
289
		if b.done {
			return nil
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
		}
	}
}

// downloadOpts bundles the arguments to downloadBlob.
type downloadOpts struct {
	// mp supplies the registry base URL and repository path for the blob.
	mp ModelPath

	// digest identifies the blob to fetch.
	digest string

	// regOpts is passed through to the registry requests.
	regOpts *RegistryOptions

	// fn receives progress updates.
	fn func(api.ProgressResponse)
}

const (
	// maxRetries is how many attempts Run makes at a single part before
	// giving up on the whole download.
	maxRetries = 3
)

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", opts.digest),
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

Michael Yang's avatar
names  
Michael Yang committed
326
327
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
328
329
330
	if !ok {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
Michael Yang's avatar
names  
Michael Yang committed
331
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
332
			blobDownloadManager.Delete(opts.digest)
333
334
335
			return err
		}

Michael Yang's avatar
names  
Michael Yang committed
336
		go download.Run(context.Background(), requestURL, opts.regOpts)
337
338
	}

Michael Yang's avatar
names  
Michael Yang committed
339
	return download.Wait(ctx, opts.fn)
340
}