download.go 5.48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package server

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path"
	"strconv"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
)

// FileDownload describes one blob download and its progress.
type FileDownload struct {
	// Digest identifies the blob and is the key under which the download is
	// registered in inProgress. FilePath is the blob's final location; data
	// is staged at FilePath+"-partial" until the download completes.
	Digest, FilePath string

	// Total is the expected size of the blob in bytes and Completed is the
	// number of bytes accounted for so far.
	Total, Completed int64
}

// inProgress maps the digest of each blob currently being downloaded to the
// *FileDownload that tracks its progress.
var inProgress sync.Map

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
	fp, err := GetBlobsPath(digest)
	if err != nil {
		return err
	}

	if fi, _ := os.Stat(fp); fi != nil {
		// we already have the file, so return
		fn(api.ProgressResponse{
			Digest:    digest,
			Total:     int(fi.Size()),
			Completed: int(fi.Size()),
		})

		return nil
	}

	fileDownload := &FileDownload{
		Digest:    digest,
		FilePath:  fp,
		Total:     1, // dummy value to indicate that we don't know the total size yet
		Completed: 0,
	}

	_, downloading := inProgress.LoadOrStore(digest, fileDownload)
	if downloading {
		// this is another client requesting the server to download the same blob concurrently
		return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
	}
Bruce MacDonald's avatar
Bruce MacDonald committed
58
	return doDownload(ctx, mp, regOpts, fileDownload, fn)
59
60
61
62
63
64
65
66
}

// downloadMu is held by monitorDownload while it inspects inProgress and, if
// the download is gone, re-registers the digest to claim the resume.
var downloadMu sync.Mutex

// monitorDownload monitors the download progress of a blob and resumes it if it is interrupted
func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
	tick := time.NewTicker(time.Second)
	for range tick.C {
Bruce MacDonald's avatar
Bruce MacDonald committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
		done, resume, err := func() (bool, bool, error) {
			downloadMu.Lock()
			defer downloadMu.Unlock()
			val, downloading := inProgress.Load(f.Digest)
			if !downloading {
				// check once again if the download is complete
				if fi, _ := os.Stat(f.FilePath); fi != nil {
					// successful download while monitoring
					fn(api.ProgressResponse{
						Digest:    f.Digest,
						Total:     int(fi.Size()),
						Completed: int(fi.Size()),
					})
					return true, false, nil
				}
				// resume the download
				inProgress.Store(f.Digest, f) // store the file download again to claim the resume
				return false, true, nil
85
			}
Bruce MacDonald's avatar
Bruce MacDonald committed
86
87
88
			f, ok := val.(*FileDownload)
			if !ok {
				return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
89
			}
Bruce MacDonald's avatar
Bruce MacDonald committed
90
91
92
93
94
95
96
97
98
99
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("downloading %s", f.Digest),
				Digest:    f.Digest,
				Total:     int(f.Total),
				Completed: int(f.Completed),
			})
			return false, false, nil
		}()
		if err != nil {
			return err
100
		}
Bruce MacDonald's avatar
Bruce MacDonald committed
101
102
103
104
105
106
		if done {
			// done downloading
			return nil
		}
		if resume {
			return doDownload(ctx, mp, regOpts, f, fn)
107
108
109
110
111
112
113
		}
	}
	return nil
}

// chunkSize is the number of bytes copied per step while downloading, and the
// granularity to which a partial file is truncated before resuming.
var chunkSize = 1 << 20 // 1 MiB in bytes

Bruce MacDonald's avatar
Bruce MacDonald committed
114
115
// doDownload downloads a blob from the registry and stores it in the blobs directory
func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
116
117
118
119
120
121
122
	var size int64

	fi, err := os.Stat(f.FilePath + "-partial")
	switch {
	case errors.Is(err, os.ErrNotExist):
		// noop, file doesn't exist so create it
	case err != nil:
Bruce MacDonald's avatar
Bruce MacDonald committed
123
		return fmt.Errorf("stat: %w", err)
124
125
126
127
128
129
130
	default:
		size = fi.Size()
		// Ensure the size is divisible by the chunk size by removing excess bytes
		size -= size % int64(chunkSize)

		err := os.Truncate(f.FilePath+"-partial", size)
		if err != nil {
Bruce MacDonald's avatar
Bruce MacDonald committed
131
			return fmt.Errorf("truncate: %w", err)
132
133
134
135
136
137
138
139
		}
	}

	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest)
	headers := map[string]string{
		"Range": fmt.Sprintf("bytes=%d-", size),
	}

140
	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
141
142
	if err != nil {
		log.Printf("couldn't download blob: %v", err)
Bruce MacDonald's avatar
Bruce MacDonald committed
143
		return err
144
	}
Bruce MacDonald's avatar
Bruce MacDonald committed
145
	defer resp.Body.Close()
146
147
148

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
		body, _ := io.ReadAll(resp.Body)
Bruce MacDonald's avatar
Bruce MacDonald committed
149
		return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
150
151
152
153
	}

	err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
	if err != nil {
Bruce MacDonald's avatar
Bruce MacDonald committed
154
		return fmt.Errorf("make blobs directory: %w", err)
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
	}

	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
	f.Completed = size
	f.Total = remaining + f.Completed

	inProgress.Store(f.Digest, f)

	out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
	if err != nil {
		return fmt.Errorf("open file: %w", err)
	}
	defer out.Close()
outerLoop:
	for {
		select {
		case <-ctx.Done():
			// handle client request cancellation
			inProgress.Delete(f.Digest)
			return nil
		default:
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("downloading %s", f.Digest),
				Digest:    f.Digest,
				Total:     int(f.Total),
				Completed: int(f.Completed),
			})

			if f.Completed >= f.Total {
				if err := out.Close(); err != nil {
					return err
				}

				if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil {
					fn(api.ProgressResponse{
						Status:    fmt.Sprintf("error renaming file: %v", err),
						Digest:    f.Digest,
						Total:     int(f.Total),
						Completed: int(f.Completed),
					})
					return err
				}

				break outerLoop
			}
		}

		n, err := io.CopyN(out, resp.Body, int64(chunkSize))
		if err != nil && !errors.Is(err, io.EOF) {
			return err
		}
		f.Completed += n

		inProgress.Store(f.Digest, f)
	}

	inProgress.Delete(f.Digest)

	log.Printf("success getting %s\n", f.Digest)
	return nil
}