download.go 6.17 KB
Newer Older
1
2
3
4
package server

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/jmorganca/ollama/api"
)

24
// blobDownloadManager maps a blob digest to its in-flight *blobDownload so that
// concurrent downloadBlob calls for the same digest share one download and each
// wait on it. downloadBlob inserts entries with LoadOrStore; Run deletes its
// entry when it returns.
var blobDownloadManager sync.Map
25

26
27
28
type blobDownload struct {
	Name   string
	Digest string
29

30
31
	Total     int64
	Completed atomic.Int64
32

33
34
	*os.File
	Parts []*blobDownloadPart
35

36
37
38
39
	done chan struct{}
	context.CancelFunc
	refCount atomic.Int32
}
40

41
42
43
44
45
// blobDownloadPart is one contiguous byte range of a blob: it starts at Offset,
// spans Size bytes, and Completed of those bytes have been downloaded. It is
// stored on disk as JSON so that an interrupted download can resume.
type blobDownloadPart struct {
	Offset, Size, Completed int64
}
46

47
48
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
49
	if err != nil {
50
		return err
51
52
	}

Michael Yang's avatar
Michael Yang committed
53
	for _, partFilePath := range partFilePaths {
54
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
55
56
		if err != nil {
			return err
57
58
		}

59
60
61
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
62
	}
63

64
65
	if len(b.Parts) == 0 {
		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
66
		if err != nil {
Michael Yang's avatar
Michael Yang committed
67
68
69
70
			return err
		}
		defer resp.Body.Close()

71
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
72
73
74
75

		var offset int64
		var size int64 = 64 * 1024 * 1024

76
77
78
79
80
81
82
83
84
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

			partName := b.Name + "-partial-" + strconv.Itoa(len(b.Parts))
			part := blobDownloadPart{Offset: offset, Size: size}
			if err := b.writePart(partName, &part); err != nil {
				return err
Michael Yang's avatar
Michael Yang committed
85
86
			}

87
			b.Parts = append(b.Parts, &part)
Michael Yang's avatar
Michael Yang committed
88
89

			offset += size
90
91
92
		}
	}

93
94
95
96
97
98
99
100
101
102
103
104
	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
	return nil
}

func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
	defer blobDownloadManager.Delete(b.Digest)

	ctx, b.CancelFunc = context.WithCancel(ctx)

	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		return err
Michael Yang's avatar
Michael Yang committed
105
	}
106
107
108
	defer b.Close()

	b.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
109

Michael Yang's avatar
Michael Yang committed
110
111
112
	b.done = make(chan struct{}, 1)
	defer close(b.done)

Michael Yang's avatar
Michael Yang committed
113
114
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(64)
115
116
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
117
118
119
		if part.Completed == part.Size {
			continue
		}
120

Michael Yang's avatar
Michael Yang committed
121
122
123
		i := i
		g.Go(func() error {
			for try := 0; try < maxRetries; try++ {
124
125
126
127
128
129
				err := b.downloadChunk(ctx, requestURL, i, opts)
				switch {
				case errors.Is(err, context.Canceled):
					return err
				case err != nil:
					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
Michael Yang's avatar
Michael Yang committed
130
					continue
131
132
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
133
134
135
136
137
				}
			}

			return errors.New("max retries exceeded")
		})
138
139
	}

Michael Yang's avatar
Michael Yang committed
140
141
	if err := g.Wait(); err != nil {
		return err
142
143
	}

144
	if err := b.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
145
146
147
		return err
	}

148
149
	for i := range b.Parts {
		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
150
151
			return err
		}
152
153
	}

154
	if err := os.Rename(b.File.Name(), b.Name); err != nil {
Michael Yang's avatar
Michael Yang committed
155
		return err
156
157
	}

158
159
160
161
162
163
164
	return nil
}

func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
	part := b.Parts[i]

	partName := b.File.Name() + "-" + strconv.Itoa(i)
Michael Yang's avatar
Michael Yang committed
165
	offset := part.Offset + part.Completed
166
	w := io.NewOffsetWriter(b.File, offset)
167

Michael Yang's avatar
Michael Yang committed
168
169
	headers := make(http.Header)
	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
170
	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
Michael Yang's avatar
Michael Yang committed
171
172
173
174
	if err != nil {
		return err
	}
	defer resp.Body.Close()
175

176
	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
Michael Yang's avatar
Michael Yang committed
177
	if err != nil && !errors.Is(err, context.Canceled) {
178
179
		// rollback progress
		b.Completed.Add(-n)
Michael Yang's avatar
Michael Yang committed
180
181
		return err
	}
182

Michael Yang's avatar
Michael Yang committed
183
	part.Completed += n
Michael Yang's avatar
Michael Yang committed
184
185
186
187
188
189
	if err := b.writePart(partName, part); err != nil {
		return err
	}

	// return nil or context.Canceled
	return err
190
191
192
193
194
195
196
197
198
199
200
201
202
}

func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
203

204
	return &part, nil
Michael Yang's avatar
Michael Yang committed
205
206
}

207
208
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
Michael Yang's avatar
Michael Yang committed
209
210
	if err != nil {
		return err
211
	}
Michael Yang's avatar
Michael Yang committed
212
	defer partFile.Close()
213

Michael Yang's avatar
Michael Yang committed
214
	return json.NewEncoder(partFile).Encode(part)
215
}
216
217
218
219
220
221
222
223
224
225
226
227
228

// Write implements io.Writer so a blobDownload can be the sink of an
// io.TeeReader. It does not keep p; it only adds len(p) to Completed and never
// fails.
func (b *blobDownload) Write(p []byte) (n int, err error) {
	b.Completed.Add(int64(len(p)))
	return len(p), nil
}

func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
	b.refCount.Add(1)

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
Michael Yang's avatar
Michael Yang committed
229
230
231
232
		case <-b.done:
			if b.Completed.Load() != b.Total {
				return io.ErrUnexpectedEOF
			}
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
		case <-ticker.C:
		case <-ctx.Done():
			if b.refCount.Add(-1) == 0 {
				b.CancelFunc()
			}

			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", b.Digest),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

		if b.Completed.Load() >= b.Total {
			<-b.done
			return nil
		}
	}
}

// downloadOpts bundles the arguments to downloadBlob:
//   - mp builds the registry blob URL (BaseURL and GetNamespaceRepository);
//   - digest names the blob and is also its blobDownloadManager key;
//   - regOpts is passed through to every registry request;
//   - fn receives progress updates.
type downloadOpts struct {
	mp      ModelPath
	digest  string
	regOpts *RegistryOptions
	fn      func(api.ProgressResponse)
}

// maxRetries is the number of attempts Run makes to download each part before
// it gives up with "max retries exceeded".
const maxRetries = 3

// downloadBlob downloads a blob from the registry and stores it in the blobs
// directory, reporting progress through opts.fn.
//
// If the blob already exists on disk, a single "complete" progress update is
// sent. Otherwise the first caller for a digest prepares the download and
// starts it in the background. Every caller, including that first one, then
// waits on the shared download until it finishes or ctx ends.
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", opts.digest),
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

	value, loaded := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := value.(*blobDownload)
	if !loaded {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
			// Run normally removes the entry, but it will never start now. Drop
			// the entry here so later calls retry instead of waiting on a
			// download that has no done channel and no cancel func.
			blobDownloadManager.Delete(opts.digest)
			return err
		}

		go download.Run(context.Background(), requestURL, opts.regOpts)
	}

	return download.Wait(ctx, opts.fn)
}