download.go 6.44 KB
Newer Older
1
2
3
4
package server

import (
	"context"
Michael Yang's avatar
Michael Yang committed
5
	"encoding/json"
6
7
8
9
10
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
Michael Yang's avatar
Michael Yang committed
11
	"net/url"
12
	"os"
Michael Yang's avatar
Michael Yang committed
13
	"path/filepath"
14
	"strconv"
Michael Yang's avatar
Michael Yang committed
15
	"strings"
16
17
18
	"sync"
	"sync/atomic"
	"time"
19

Michael Yang's avatar
Michael Yang committed
20
	"golang.org/x/sync/errgroup"
21
22

	"github.com/jmorganca/ollama/api"
23
24
)

25
// blobDownloadManager tracks in-flight blob downloads. downloadBlob stores
// *blobDownload values in it keyed by blob digest (via LoadOrStore), and
// blobDownload.Run deletes its own entry when it returns.
var blobDownloadManager sync.Map
26

27
28
29
type blobDownload struct {
	Name   string
	Digest string
30

31
32
	Total     int64
	Completed atomic.Int64
33

34
35
	*os.File
	Parts []*blobDownloadPart
36

37
38
39
40
	done chan struct{}
	context.CancelFunc
	refCount atomic.Int32
}
41

42
type blobDownloadPart struct {
Michael Yang's avatar
Michael Yang committed
43
	N         int
44
45
46
	Offset    int64
	Size      int64
	Completed int64
Michael Yang's avatar
Michael Yang committed
47
48
49
50
51
52
53
54

	*blobDownload `json:"-"`
}

// Name returns the path of the file that stores this part's metadata: the
// owning download's Name, the word "partial" and the part index N, joined
// with "-".
func (p *blobDownloadPart) Name() string {
	return strings.Join([]string{
		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
	}, "-")
}
56

57
58
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
Michael Yang's avatar
Michael Yang committed
59
	if err != nil {
60
		return err
61
62
	}

Michael Yang's avatar
Michael Yang committed
63
	for _, partFilePath := range partFilePaths {
64
		part, err := b.readPart(partFilePath)
Bruce MacDonald's avatar
Bruce MacDonald committed
65
66
		if err != nil {
			return err
67
68
		}

69
70
71
		b.Total += part.Size
		b.Completed.Add(part.Completed)
		b.Parts = append(b.Parts, part)
Michael Yang's avatar
Michael Yang committed
72
	}
73

74
75
	if len(b.Parts) == 0 {
		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
76
		if err != nil {
Michael Yang's avatar
Michael Yang committed
77
78
79
80
			return err
		}
		defer resp.Body.Close()

81
		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Michael Yang's avatar
Michael Yang committed
82
83
84
85

		var offset int64
		var size int64 = 64 * 1024 * 1024

86
87
88
89
90
		for offset < b.Total {
			if offset+size > b.Total {
				size = b.Total - offset
			}

Michael Yang's avatar
Michael Yang committed
91
			if err := b.newPart(offset, size); err != nil {
92
				return err
Michael Yang's avatar
Michael Yang committed
93
94
95
			}

			offset += size
96
97
98
		}
	}

99
100
101
102
103
104
105
106
107
108
109
110
	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
	return nil
}

func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
	defer blobDownloadManager.Delete(b.Digest)

	ctx, b.CancelFunc = context.WithCancel(ctx)

	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
	if err != nil {
		return err
Michael Yang's avatar
Michael Yang committed
111
	}
112
113
114
	defer b.Close()

	b.Truncate(b.Total)
Michael Yang's avatar
Michael Yang committed
115

Michael Yang's avatar
Michael Yang committed
116
117
118
	b.done = make(chan struct{}, 1)
	defer close(b.done)

Michael Yang's avatar
Michael Yang committed
119
120
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(64)
121
122
	for i := range b.Parts {
		part := b.Parts[i]
Michael Yang's avatar
Michael Yang committed
123
124
125
		if part.Completed == part.Size {
			continue
		}
126

Michael Yang's avatar
Michael Yang committed
127
128
129
		i := i
		g.Go(func() error {
			for try := 0; try < maxRetries; try++ {
130
131
132
133
134
135
				err := b.downloadChunk(ctx, requestURL, i, opts)
				switch {
				case errors.Is(err, context.Canceled):
					return err
				case err != nil:
					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
Michael Yang's avatar
Michael Yang committed
136
					continue
137
138
				default:
					return nil
Michael Yang's avatar
Michael Yang committed
139
140
141
142
143
				}
			}

			return errors.New("max retries exceeded")
		})
144
145
	}

Michael Yang's avatar
Michael Yang committed
146
147
	if err := g.Wait(); err != nil {
		return err
148
149
	}

150
	if err := b.Close(); err != nil {
Michael Yang's avatar
Michael Yang committed
151
152
153
		return err
	}

154
155
	for i := range b.Parts {
		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
Michael Yang's avatar
Michael Yang committed
156
157
			return err
		}
158
159
	}

160
	if err := os.Rename(b.File.Name(), b.Name); err != nil {
Michael Yang's avatar
Michael Yang committed
161
		return err
162
163
	}

164
165
166
167
168
169
	return nil
}

func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
	part := b.Parts[i]

Michael Yang's avatar
Michael Yang committed
170
	offset := part.Offset + part.Completed
171
	w := io.NewOffsetWriter(b.File, offset)
172

Michael Yang's avatar
Michael Yang committed
173
174
	headers := make(http.Header)
	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
175
	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
Michael Yang's avatar
Michael Yang committed
176
177
178
179
	if err != nil {
		return err
	}
	defer resp.Body.Close()
180

181
	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
Michael Yang's avatar
Michael Yang committed
182
	if err != nil && !errors.Is(err, context.Canceled) {
183
184
		// rollback progress
		b.Completed.Add(-n)
Michael Yang's avatar
Michael Yang committed
185
186
		return err
	}
187

Michael Yang's avatar
Michael Yang committed
188
	part.Completed += n
Michael Yang's avatar
Michael Yang committed
189
	if err := b.writePart(part.Name(), part); err != nil {
Michael Yang's avatar
Michael Yang committed
190
191
192
193
194
		return err
	}

	// return nil or context.Canceled
	return err
195
196
}

Michael Yang's avatar
Michael Yang committed
197
198
199
200
201
202
203
204
205
206
// newPart records a new part covering size bytes starting at offset. The
// part is numbered with the current length of b.Parts, saved to its part
// file, and appended to b.Parts only if that save succeeded.
func (b *blobDownload) newPart(offset, size int64) error {
	p := &blobDownloadPart{
		blobDownload: b,
		N:            len(b.Parts),
		Offset:       offset,
		Size:         size,
	}

	err := b.writePart(p.Name(), p)
	if err == nil {
		b.Parts = append(b.Parts, p)
	}

	return err
}

207
208
209
210
211
212
213
214
215
216
217
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	var part blobDownloadPart
	partFile, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer partFile.Close()

	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
		return nil, err
	}
218

Michael Yang's avatar
Michael Yang committed
219
	part.blobDownload = b
220
	return &part, nil
Michael Yang's avatar
Michael Yang committed
221
222
}

223
224
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
Michael Yang's avatar
Michael Yang committed
225
226
	if err != nil {
		return err
227
	}
Michael Yang's avatar
Michael Yang committed
228
	defer partFile.Close()
229

Michael Yang's avatar
Michael Yang committed
230
	return json.NewEncoder(partFile).Encode(part)
231
}
232
233
234
235
236
237
238
239
240
241
242
243
244

// Write adds len(p) to b.Completed and discards the data. It always reports
// the full length as written and never returns an error.
func (b *blobDownload) Write(p []byte) (n int, err error) {
	b.Completed.Add(int64(len(p)))
	return len(p), nil
}

func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
	b.refCount.Add(1)

	ticker := time.NewTicker(60 * time.Millisecond)
	for {
		select {
Michael Yang's avatar
Michael Yang committed
245
246
247
248
		case <-b.done:
			if b.Completed.Load() != b.Total {
				return io.ErrUnexpectedEOF
			}
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
		case <-ticker.C:
		case <-ctx.Done():
			if b.refCount.Add(-1) == 0 {
				b.CancelFunc()
			}

			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", b.Digest),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

		if b.Completed.Load() >= b.Total {
			<-b.done
			return nil
		}
	}
}

// downloadOpts bundles the arguments of downloadBlob:
//   - mp supplies the registry base URL and the namespace/repository path;
//   - digest identifies the blob to fetch;
//   - regOpts is passed through to the registry requests;
//   - fn receives progress updates.
type downloadOpts struct {
	mp      ModelPath
	digest  string
	regOpts *RegistryOptions
	fn      func(api.ProgressResponse)
}

// maxRetries is the number of attempts blobDownload.Run makes for each part
// before giving up with "max retries exceeded".
const maxRetries = 3

// downloadBlob downloads a blob from the registry and stores it in the blobs directory.
//
// If the blob file already exists, opts.fn is called once with a completed
// progress report and nothing is downloaded. Otherwise the call joins the
// in-flight download for opts.digest, or prepares and starts one in the
// background, and then waits for it while reporting progress to opts.fn.
//
// It returns the error from resolving or stat-ing the blob path, from
// Prepare, or from Wait.
func downloadBlob(ctx context.Context, opts downloadOpts) error {
	fp, err := GetBlobsPath(opts.digest)
	if err != nil {
		return err
	}

	fi, err := os.Stat(fp)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Not downloaded yet; continue below.
	case err != nil:
		return err
	default:
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("downloading %s", opts.digest),
			Digest:    opts.digest,
			Total:     fi.Size(),
			Completed: fi.Size(),
		})

		return nil
	}

	value, loaded := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := value.(*blobDownload)
	if !loaded {
		requestURL := opts.mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
			// Run, which normally removes the entry, is never started on
			// this path. Remove it here so later calls for this digest do
			// not wait on a download that nobody is running.
			blobDownloadManager.Delete(opts.digest)
			return err
		}

		// The download runs under its own background context so it is not
		// tied to the context of the caller that happened to start it.
		go download.Run(context.Background(), requestURL, opts.regOpts)
	}

	return download.Wait(ctx, opts.fn)
}