download.go 6.72 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
9
10
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
Michael Yang's avatar
Michael Yang committed
11
	"net/url"
12
	"os"
Michael Yang's avatar
Michael Yang committed
13
	"path/filepath"
14
	"strconv"
Michael Yang's avatar
Michael Yang committed
15
	"strings"
16
17
18
	"sync"
	"sync/atomic"
	"time"
19

Michael Yang's avatar
Michael Yang committed
20
	"golang.org/x/sync/errgroup"
21
22

	"github.com/jmorganca/ollama/api"
23
24
)

25
// blobDownloadManager maps a blob digest to its in-flight *blobDownload so
// that concurrent requests for the same blob share one download. Run removes
// the entry when it returns.
var blobDownloadManager sync.Map
26

27
28
29
type blobDownload struct {
	Name   string
	Digest string
30

31
32
	Total     int64
	Completed atomic.Int64
33

34
35
	*os.File
	Parts []*blobDownloadPart
36

37
38
	done chan struct{}
	context.CancelFunc
Michael Yang's avatar
Michael Yang committed
39
	references atomic.Int32
40
}
41

42
type blobDownloadPart struct {
Michael Yang's avatar
Michael Yang committed
43
	N         int
44
45
46
	Offset    int64
	Size      int64
	Completed int64
Michael Yang's avatar
Michael Yang committed
47
48
49
50
51
52
53
54

	*blobDownload `json:"-"`
}

// Name returns the path of this part's JSON state file,
// "<blob path>-partial-<N>". The download's Name field is referenced
// explicitly because this method would otherwise shadow it.
func (p *blobDownloadPart) Name() string {
	return strings.Join([]string{
		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
	}, "-")
}
56

57
58
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
59
	if err != nil {
60
		return err
61
62
	}

Michael Yang's avatar
Michael Yang committed
63
	for _, partFilePath := range partFilePaths {
64
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
65
66
		if err != nil {
			return err
67
68
		}

69
70
71
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
72
	}
73

74
75
	if len(b.Parts) == 0 {
		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
76
		if err != nil {
Michael Yang's avatar
Michael Yang committed
77
78
79
80
			return err
		}
		defer resp.Body.Close()

Michael Yang's avatar
Michael Yang committed
81
82
83
84
85
		if resp.StatusCode >= http.StatusBadRequest {
			body, _ := io.ReadAll(resp.Body)
			return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
		}

86
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
87
88
89
90

		var offset int64
		var size int64 = 64 * 1024 * 1024

91
92
93
94
95
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
96
			if err := b.newPart(offset, size); err != nil {
97
				return err
Michael Yang's avatar
Michael Yang committed
98
99
100
			}

			offset += size
101
102
103
		}
	}

104
105
106
107
108
109
110
111
112
113
114
115
	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
	return nil
}

func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
	defer blobDownloadManager.Delete(b.Digest)

	ctx, b.CancelFunc = context.WithCancel(ctx)

	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		return err
Michael Yang's avatar
Michael Yang committed
116
	}
117
118
119
	defer b.Close()

	b.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
120

Michael Yang's avatar
Michael Yang committed
121
122
123
	b.done = make(chan struct{}, 1)
	defer close(b.done)

Michael Yang's avatar
Michael Yang committed
124
125
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(64)
126
127
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
128
129
130
		if part.Completed == part.Size {
			continue
		}
131

Michael Yang's avatar
Michael Yang committed
132
133
134
		i := i
		g.Go(func() error {
			for try := 0; try < maxRetries; try++ {
135
136
137
138
139
140
				err := b.downloadChunk(ctx, requestURL, i, opts)
				switch {
				case errors.Is(err, context.Canceled):
					return err
				case err != nil:
					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
Michael Yang's avatar
Michael Yang committed
141
					continue
142
143
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
144
145
146
147
148
				}
			}

			return errors.New("max retries exceeded")
		})
149
150
	}

Michael Yang's avatar
Michael Yang committed
151
152
	if err := g.Wait(); err != nil {
		return err
153
154
	}

155
	if err := b.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
156
157
158
		return err
	}

159
160
	for i := range b.Parts {
		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
161
162
			return err
		}
163
164
	}

165
	if err := os.Rename(b.File.Name(), b.Name); err != nil {
Michael Yang's avatar
Michael Yang committed
166
		return err
167
168
	}

169
170
171
172
173
174
	return nil
}

func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
	part := b.Parts[i]

Michael Yang's avatar
Michael Yang committed
175
	offset := part.Offset + part.Completed
176
	w := io.NewOffsetWriter(b.File, offset)
177

Michael Yang's avatar
Michael Yang committed
178
179
	headers := make(http.Header)
	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
180
	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
Michael Yang's avatar
Michael Yang committed
181
182
183
184
	if err != nil {
		return err
	}
	defer resp.Body.Close()
185

186
	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
Michael Yang's avatar
Michael Yang committed
187
	if err != nil && !errors.Is(err, context.Canceled) {
188
189
		// rollback progress
		b.Completed.Add(-n)
Michael Yang's avatar
Michael Yang committed
190
191
		return err
	}
192

Michael Yang's avatar
Michael Yang committed
193
	part.Completed += n
Michael Yang's avatar
Michael Yang committed
194
	if err := b.writePart(part.Name(), part); err != nil {
Michael Yang's avatar
Michael Yang committed
195
196
197
198
199
		return err
	}

	// return nil or context.Canceled
	return err
200
201
}

Michael Yang's avatar
Michael Yang committed
202
203
204
205
206
207
208
209
210
211
// newPart creates a not-yet-downloaded part for the range
// [offset, offset+size). It writes the part's state file first and appends
// the part to b.Parts only if that write succeeds.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		blobDownload: b,
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
	}

	if err := b.writePart(p.Name(), p); err != nil {
		return err
	}

	b.Parts = append(b.Parts, p)
	return nil
}

212
213
214
215
216
217
218
219
220
221
222
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
223

Michael Yang's avatar
Michael Yang committed
224
	part.blobDownload = b
225
	return &part, nil
Michael Yang's avatar
Michael Yang committed
226
227
}

228
229
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
Michael Yang's avatar
Michael Yang committed
230
231
	if err != nil {
		return err
232
	}
Michael Yang's avatar
Michael Yang committed
233
	defer partFile.Close()
234

Michael Yang's avatar
Michael Yang committed
235
	return json.NewEncoder(partFile).Encode(part)
236
}
237
238
239
240
241
242
243

// Write makes blobDownload an io.Writer that only counts bytes. The data in
// p is discarded, len(p) is added to Completed, and the call never fails.
func (b *blobDownload) Write(p []byte) (int, error) {
	size := len(p)
	b.Completed.Add(int64(size))
	return size, nil
}

Michael Yang's avatar
Michael Yang committed
244
245
246
247
248
249
250
251
252
253
// acquire registers one more waiter on the download. Each call must be
// paired with a call to release.
func (b *blobDownload) acquire() {
	b.references.Add(1)
}

// release drops a waiter registered by acquire. When the last waiter leaves,
// the download is cancelled so it does not keep running for nobody.
func (b *blobDownload) release() {
	if b.references.Add(-1) != 0 {
		return
	}

	// CancelFunc is installed by Run, which runs on its own goroutine. A
	// waiter can give up before Run has set it, and calling a nil func would
	// panic.
	if cancel := b.CancelFunc; cancel != nil {
		cancel()
	}
}

254
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
255
256
	b.acquire()
	defer b.release()
257
258
259
260

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
Michael Yang's avatar
Michael Yang committed
261
262
263
264
		case <-b.done:
			if b.Completed.Load() != b.Total {
				return io.ErrUnexpectedEOF
			}
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", b.Digest),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

		if b.Completed.Load() >= b.Total {
			<-b.done
			return nil
		}
	}
}

// downloadOpts bundles the arguments to downloadBlob.
type downloadOpts struct {
	mp      ModelPath
	digest  string
	regOpts *RegistryOptions
	fn      func(api.ProgressResponse) // receives progress updates
}

// maxRetries is how many times Run attempts each part before giving up.
const maxRetries = 3

// downloadBlob downloads a blob from the registry and stores it in the blobs directory.
//
// If the blob is already on disk, it reports a single completed progress
// update and fetches nothing. Concurrent calls for the same digest share one
// download. The first caller prepares it and starts Run in the background.
// Every caller, including the first, then waits on it and receives progress
// through opts.fn.
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", opts.digest),
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

	value, loaded := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := value.(*blobDownload)
	if !loaded {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
			// Run never starts, so nothing else would remove this entry. Drop
			// it so the next request for this digest can retry.
			// NOTE(review): callers that already attached to the entry are
			// not woken by this.
			blobDownloadManager.Delete(opts.digest)
			return err
		}

		go func() {
			// Cancellation only means every waiter has left. Any other error
			// would otherwise be lost, since waiters see only
			// io.ErrUnexpectedEOF.
			if err := download.Run(context.Background(), requestURL, opts.regOpts); err != nil && !errors.Is(err, context.Canceled) {
				log.Printf("download of %s failed: %v", opts.digest, err)
			}
		}()
	}

	return download.Wait(ctx, opts.fn)
}