upload.go 4.32 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package server

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"strconv"

	"github.com/jmorganca/ollama/api"
)

// startUpload opens a blob upload session for layer in the repository named by
// mp by POSTing to /v2/<namespace>/<repository>/blobs/uploads/.
//
// When layer.From is set, the request also carries the mount and from query
// parameters, asking the registry to mount the blob from that repository.
// NOTE(review): per the OCI distribution spec, a registry that performs the
// mount answers 201 Created with the blob's location instead of an upload
// session. Confirm that callers handle that case.
//
// It returns the URL from the response's Location header, resolved against the
// request URL because the header may be a relative reference.
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
	requestURL := mp.BaseURL()
	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
	if layer.From != "" {
		values := requestURL.Query()
		values.Add("mount", layer.Digest)
		values.Add("from", layer.From)
		requestURL.RawQuery = values.Encode()
	}

	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
	if err != nil {
		log.Printf("couldn't start upload: %v", err)
		return nil, err
	}
	defer resp.Body.Close()

	// Report error statuses explicitly rather than as a confusing
	// "location header is missing" error.
	if resp.StatusCode >= http.StatusBadRequest {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("on start upload registry responded with code %d: %s", resp.StatusCode, body)
	}

	// Extract UUID location from header
	location := resp.Header.Get("Location")
	if location == "" {
		return nil, errors.New("location header is missing in response")
	}

	locationURL, err := url.Parse(location)
	if err != nil {
		return nil, err
	}

	// Location may be a relative reference. Resolving it against the request
	// URL yields an absolute URL and returns absolute URLs unchanged.
	return requestURL.ResolveReference(locationURL), nil
}

Michael Yang's avatar
Michael Yang committed
43
func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
Michael Yang's avatar
Michael Yang committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
	// TODO allow resumability
	// TODO allow canceling uploads via DELETE

	fp, err := GetBlobsPath(layer.Digest)
	if err != nil {
		return err
	}

	f, err := os.Open(fp)
	if err != nil {
		return err
	}
	defer f.Close()

Michael Yang's avatar
Michael Yang committed
58
59
	// 95MB chunk size
	chunkSize := 95 * 1024 * 1024
Michael Yang's avatar
Michael Yang committed
60

Michael Yang's avatar
Michael Yang committed
61
62
	for offset := int64(0); offset < int64(layer.Size); {
		chunk := int64(layer.Size) - offset
Michael Yang's avatar
Michael Yang committed
63
64
65
66
		if chunk > int64(chunkSize) {
			chunk = int64(chunkSize)
		}

Michael Yang's avatar
Michael Yang committed
67
68
		sectionReader := io.NewSectionReader(f, int64(offset), chunk)
		for try := 0; try < MaxRetries; try++ {
Michael Yang's avatar
Michael Yang committed
69
70
			ch := make(chan error, 1)

Michael Yang's avatar
Michael Yang committed
71
72
73
74
75
76
			r, w := io.Pipe()
			defer r.Close()
			go func() {
				defer w.Close()

				for chunked := int64(0); chunked < chunk; {
Michael Yang's avatar
Michael Yang committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
					select {
					case err := <-ch:
						log.Printf("chunk interrupted: %v", err)
						return
					default:
						n, err := io.CopyN(w, sectionReader, 1024*1024)
						if err != nil && !errors.Is(err, io.EOF) {
							fn(api.ProgressResponse{
								Status:    fmt.Sprintf("error reading chunk: %v", err),
								Digest:    layer.Digest,
								Total:     layer.Size,
								Completed: int(offset),
							})

							return
						}

						chunked += n
Michael Yang's avatar
Michael Yang committed
95
						fn(api.ProgressResponse{
Michael Yang's avatar
Michael Yang committed
96
							Status:    fmt.Sprintf("uploading %s", layer.Digest),
Michael Yang's avatar
Michael Yang committed
97
98
							Digest:    layer.Digest,
							Total:     layer.Size,
Michael Yang's avatar
Michael Yang committed
99
							Completed: int(offset) + int(chunked),
Michael Yang's avatar
Michael Yang committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
						})
					}
				}
			}()

			headers := make(http.Header)
			headers.Set("Content-Type", "application/octet-stream")
			headers.Set("Content-Length", strconv.Itoa(int(chunk)))
			headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
			resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts)
			if err != nil && !errors.Is(err, io.EOF) {
				fn(api.ProgressResponse{
					Status:    fmt.Sprintf("error uploading chunk: %v", err),
					Digest:    layer.Digest,
					Total:     layer.Size,
					Completed: int(offset),
				})

				return err
			}
			defer resp.Body.Close()

Michael Yang's avatar
Michael Yang committed
122
123
			switch {
			case resp.StatusCode == http.StatusUnauthorized:
Michael Yang's avatar
Michael Yang committed
124
125
				ch <- errors.New("unauthorized")

Michael Yang's avatar
Michael Yang committed
126
127
				auth := resp.Header.Get("www-authenticate")
				authRedir := ParseAuthRedirectString(auth)
Michael Yang's avatar
Michael Yang committed
128
				token, err := getAuthToken(ctx, authRedir)
Michael Yang's avatar
Michael Yang committed
129
130
131
132
133
				if err != nil {
					return err
				}

				regOpts.Token = token
Michael Yang's avatar
Michael Yang committed
134
				sectionReader = io.NewSectionReader(f, int64(offset), chunk)
Michael Yang's avatar
Michael Yang committed
135
				continue
Michael Yang's avatar
Michael Yang committed
136
			case resp.StatusCode >= http.StatusBadRequest:
Michael Yang's avatar
Michael Yang committed
137
138
139
140
141
142
143
144
145
				body, _ := io.ReadAll(resp.Body)
				return fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
			}

			offset += sectionReader.Size()
			requestURL, err = url.Parse(resp.Header.Get("Location"))
			if err != nil {
				return err
			}
Michael Yang's avatar
Michael Yang committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

			break
		}
	}

	values := requestURL.Query()
	values.Add("digest", layer.Digest)
	requestURL.RawQuery = values.Encode()

	headers := make(http.Header)
	headers.Set("Content-Type", "application/octet-stream")
	headers.Set("Content-Length", "0")

	// finish the upload
	resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
	if err != nil {
		log.Printf("couldn't finish upload: %v", err)
		return err
	}
	defer resp.Body.Close()

Michael Yang's avatar
Michael Yang committed
167
	if resp.StatusCode >= http.StatusBadRequest {
Michael Yang's avatar
Michael Yang committed
168
169
170
171
172
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
	}
	return nil
}