images.go 32.4 KB
Newer Older
1
2
3
package server

import (
4
	"bufio"
5
	"bytes"
6
	"context"
7
8
9
10
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
11
	"html/template"
12
13
14
15
	"io"
	"log"
	"net/http"
	"os"
Michael Yang's avatar
Michael Yang committed
16
	"path"
17
	"path/filepath"
Michael Yang's avatar
Michael Yang committed
18
	"reflect"
19
20
21
22
	"strconv"
	"strings"

	"github.com/jmorganca/ollama/api"
23
	"github.com/jmorganca/ollama/llm"
24
	"github.com/jmorganca/ollama/parser"
25
	"github.com/jmorganca/ollama/vector"
26
27
)

28
29
// MaxRetries is an untyped retry limit. It is not referenced in this part of
// the file; NOTE(review): confirm which registry operations consume it.
const (
	MaxRetries = 3
)

30
31
32
33
// RegistryOptions carries per-request settings for talking to a model
// registry: whether insecure connections are allowed, and the credentials
// (username/password or a token) to authenticate with. CreateModel passes a
// zero value when it pulls a missing base model.
// NOTE(review): exact use of each field lives in the request helpers outside
// this section — confirm there.
type RegistryOptions struct {
	Insecure bool
	Username string
	Password string
	Token    string
}

37
type Model struct {
38
39
40
41
42
43
44
45
	Name         string `json:"name"`
	ModelPath    string
	AdapterPaths []string
	Template     string
	System       string
	Digest       string
	Options      map[string]interface{}
	Embeddings   []vector.Embedding
46
47
}

Bruce MacDonald's avatar
Bruce MacDonald committed
48
// Prompt renders the model's prompt template for request.
//
// request.Template and request.System, when non-empty, override the model's
// stored template and system prompt. The template can reference .First (true
// when the request carries no prior Context), .System, .Prompt, .Embed (the
// embedding text passed in) and the deprecated .Context. Parse and execution
// errors are returned unchanged.
//
// NOTE(review): the file imports html/template, which HTML-escapes the
// substituted values (for example ' and <). That is probably not wanted for
// LLM prompts; text/template may have been intended — confirm.
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
	source := m.Template
	if request.Template != "" {
		source = request.Template
	}

	tmpl, err := template.New("").Parse(source)
	if err != nil {
		return "", err
	}

	system := m.System
	if request.System != "" {
		system = request.System
	}

	vars := struct {
		First  bool
		System string
		Prompt string
		Embed  string

		// deprecated: versions <= 0.0.7 used this to omit the system prompt
		Context []int
	}{
		First:   len(request.Context) == 0,
		System:  system,
		Prompt:  request.Prompt,
		Embed:   embedding,
		Context: request.Context,
	}

	var rendered strings.Builder
	if err := tmpl.Execute(&rendered, vars); err != nil {
		return "", err
	}
	return rendered.String(), nil
}

87
88
89
90
91
92
93
94
95
96
97
// ManifestV2 is a Docker distribution manifest (schema version 2) describing
// a model: one config blob plus the ordered list of content layers.
// Field order and JSON tags define the serialized form and hence its digest.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        Layer    `json:"config"`
	Layers        []*Layer `json:"layers"`
}

// Layer is a content descriptor inside a manifest: the blob's media type, its
// "sha256:..." digest and its size in bytes. From optionally records the
// repository a layer was inherited from (CreateModel sets it when reusing a
// base model's layers) and is omitted from JSON when empty.
type Layer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int    `json:"size"`
	From      string `json:"from,omitempty"`
}

Michael Yang's avatar
Michael Yang committed
101
type LayerReader struct {
102
	Layer
Michael Yang's avatar
Michael Yang committed
103
	io.Reader
104
105
106
}

type ConfigV2 struct {
107
108
109
110
111
112
	ModelFamily llm.ModelFamily `json:"model_family"`
	ModelType   llm.ModelType   `json:"model_type"`
	FileType    llm.FileType    `json:"file_type"`
	RootFS      RootFS          `json:"rootfs"`

	// required by spec
113
114
115
116
117
118
119
120
121
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}

// RootFS lists the layer digests of an image config. createConfigLayer sets
// Type to "layers" and DiffIDs to the digests of the model's layers.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}

Patrick Devine's avatar
Patrick Devine committed
122
123
124
125
126
127
128
129
130
// GetTotalSize returns the combined size in bytes of the config blob and
// every layer listed in the manifest.
func (m *ManifestV2) GetTotalSize() int {
	sum := m.Config.Size
	for i := range m.Layers {
		sum += m.Layers[i].Size
	}
	return sum
}

Patrick Devine's avatar
Patrick Devine committed
131
132
// GetManifest reads and decodes the locally stored manifest for mp.
//
// The os.Stat error for a missing manifest is returned unwrapped so callers
// can still classify it with os.IsNotExist. Read failures are wrapped with the
// manifest path. A manifest file containing JSON null is rejected instead of
// returning a nil manifest with a nil error.
func GetManifest(mp ModelPath) (*ManifestV2, error) {
	fp, err := mp.GetManifestPath(false)
	if err != nil {
		return nil, err
	}

	if _, err := os.Stat(fp); err != nil {
		return nil, err
	}

	bts, err := os.ReadFile(fp)
	if err != nil {
		return nil, fmt.Errorf("couldn't open file '%s': %w", fp, err)
	}

	var manifest *ManifestV2
	if err := json.Unmarshal(bts, &manifest); err != nil {
		return nil, err
	}
	// Unmarshal of a JSON null leaves the pointer nil; callers dereference the
	// result unconditionally, so report it as an error instead.
	if manifest == nil {
		return nil, fmt.Errorf("manifest '%s' is empty", fp)
	}

	return manifest, nil
}

func GetModel(name string) (*Model, error) {
Patrick Devine's avatar
Patrick Devine committed
156
157
158
	mp := ParseModelPath(name)

	manifest, err := GetManifest(mp)
159
160
161
162
163
	if err != nil {
		return nil, err
	}

	model := &Model{
164
		Name:   mp.GetFullTagname(),
Jeffrey Morgan's avatar
Jeffrey Morgan committed
165
		Digest: manifest.Config.Digest,
166
167
168
	}

	for _, layer := range manifest.Layers {
Patrick Devine's avatar
Patrick Devine committed
169
		filename, err := GetBlobsPath(layer.Digest)
170
171
172
173
		if err != nil {
			return nil, err
		}

174
175
176
		switch layer.MediaType {
		case "application/vnd.ollama.image.model":
			model.ModelPath = filename
177
178
179
180
181
182
183
184
185
186
		case "application/vnd.ollama.image.embed":
			file, err := os.Open(filename)
			if err != nil {
				return nil, fmt.Errorf("failed to open file: %s", filename)
			}
			defer file.Close()

			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
				return nil, err
			}
187
188
		case "application/vnd.ollama.image.adapter":
			model.AdapterPaths = append(model.AdapterPaths, filename)
189
190
191
192
193
194
195
196
197
		case "application/vnd.ollama.image.template":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.Template = string(bts)
		case "application/vnd.ollama.image.system":
			bts, err := os.ReadFile(filename)
198
199
200
			if err != nil {
				return nil, err
			}
201
202

			model.System = string(bts)
203
204
205
206
207
208
209
		case "application/vnd.ollama.image.prompt":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.Template = string(bts)
210
		case "application/vnd.ollama.image.params":
Michael Yang's avatar
Michael Yang committed
211
212
213
214
215
			params, err := os.Open(filename)
			if err != nil {
				return nil, err
			}
			defer params.Close()
216

217
			// parse model options parameters into a map so that we can see which fields have been specified explicitly
218
			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
219
220
				return nil, err
			}
221
222
223
224
225
226
		}
	}

	return model, nil
}

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
// filenameWithPath resolves f, a file reference taken from the Modelfile
// located at path, to a filesystem path.
//
// A leading "~/" is expanded to the current user's home directory. A path
// that is still relative after that is interpreted relative to the directory
// containing the Modelfile. Absolute paths are returned (cleaned) as-is.
func filenameWithPath(path, f string) (string, error) {
	if strings.HasPrefix(f, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("failed to open file: %v", err)
		}
		f = filepath.Join(home, strings.TrimPrefix(f, "~/"))
	}

	if filepath.IsAbs(f) {
		return f, nil
	}
	return filepath.Join(filepath.Dir(path), f), nil
}

247
// CreateModel builds a model named name from the Modelfile at path, reporting
// progress through fn.
//
// Each Modelfile command becomes one or more layers:
//   - "model" reuses the layers of an existing local model. If there is none,
//     it creates a model layer from a local weights file; if that file does not
//     exist either, it pulls the model from its registry.
//   - "embed" queues files whose lines are embedded after all commands are read.
//   - "adapter" adds an adapter layer from a local file.
//   - "license" adds a license layer.
//   - "template", "system" and "prompt" replace any inherited layer of that kind.
//   - Anything else is collected as a runtime parameter; all parameters are
//     written to a single params layer.
//
// Finally the layers and a config blob are written to the blob store, and a
// manifest is saved under name.
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
	mf, err := os.Open(path)
	if err != nil {
		fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
		return fmt.Errorf("failed to open file: %w", err)
	}
	defer mf.Close()

	fn(api.ProgressResponse{Status: "parsing modelfile"})
	commands, err := parser.Parse(mf)
	if err != nil {
		return err
	}

	config := ConfigV2{
		Architecture: "amd64",
		OS:           "linux",
	}

	var layers []*LayerReader
	params := make(map[string][]string)
	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
	for _, c := range commands {
		log.Printf("[%s] - %s\n", c.Name, c.Args)
		switch c.Name {
		case "model":
			fn(api.ProgressResponse{Status: "looking for model"})
			embed.model = c.Args
			mp := ParseModelPath(c.Args)
			manifest, err := GetManifest(mp)
			if err != nil {
				modelFile, err := filenameWithPath(path, c.Args)
				if err != nil {
					return err
				}

				_, statErr := os.Stat(modelFile)
				switch {
				case statErr == nil:
					// create a model from this specified file
					fn(api.ProgressResponse{Status: "creating model layer"})
					file, err := os.Open(modelFile)
					if err != nil {
						return fmt.Errorf("failed to open file: %w", err)
					}
					// The layer reads from file later in SaveLayers, so it must stay
					// open until CreateModel returns.
					defer file.Close()

					ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
					if err != nil {
						return err
					}

					config.ModelFamily = ggml.ModelFamily
					config.ModelType = ggml.ModelType
					config.FileType = ggml.FileType

					// reset the file; a failed rewind would hash and store the wrong bytes
					if _, err := file.Seek(0, io.SeekStart); err != nil {
						return fmt.Errorf("failed to reset model file: %w", err)
					}

					l, err := CreateLayer(file)
					if err != nil {
						return fmt.Errorf("failed to create layer: %w", err)
					}
					l.MediaType = "application/vnd.ollama.image.model"
					layers = append(layers, l)

					// embeddingLayers stats embed.model as given, so hand it the
					// resolved path rather than the (possibly relative or ~/) argument.
					embed.model = modelFile
				case errors.Is(statErr, os.ErrNotExist):
					// the model file does not exist, try pulling it
					fn(api.ProgressResponse{Status: "pulling model file"})
					if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
						return err
					}
					manifest, err = GetManifest(mp)
					if err != nil {
						return fmt.Errorf("failed to open file after pull: %w", err)
					}
				default:
					return statErr
				}
			}

			if manifest != nil {
				log.Printf("manifest = %#v", manifest)
				for _, l := range manifest.Layers {
					newLayer, err := GetLayerWithBufferFromLayer(l)
					if err != nil {
						return err
					}
					newLayer.From = mp.GetNamespaceRepository()
					layers = append(layers, newLayer)
				}
			}
		case "embed":
			embedFilePath, err := filenameWithPath(path, c.Args)
			if err != nil {
				return err
			}
			embed.files = append(embed.files, embedFilePath)
		case "adapter":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})

			fp, err := filenameWithPath(path, c.Args)
			if err != nil {
				return err
			}

			// create a model from this specified file
			fn(api.ProgressResponse{Status: "creating model layer"})

			file, err := os.Open(fp)
			if err != nil {
				return fmt.Errorf("failed to open file: %w", err)
			}
			// kept open until SaveLayers has copied the layer
			defer file.Close()

			l, err := CreateLayer(file)
			if err != nil {
				return fmt.Errorf("failed to create layer: %w", err)
			}
			l.MediaType = "application/vnd.ollama.image.adapter"
			layers = append(layers, l)
		case "license":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

			layer, err := CreateLayer(strings.NewReader(c.Args))
			if err != nil {
				return err
			}

			layer.MediaType = mediaType
			layers = append(layers, layer)
		case "template", "system", "prompt":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
			// remove the layer if one exists
			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
			layers = removeLayerFromLayers(layers, mediaType)

			layer, err := CreateLayer(strings.NewReader(c.Args))
			if err != nil {
				return err
			}

			layer.MediaType = mediaType
			layers = append(layers, layer)
		default:
			// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
			params[c.Name] = append(params[c.Name], c.Args)
		}
	}

	// Create a single layer for the parameters
	if len(params) > 0 {
		fn(api.ProgressResponse{Status: "creating parameter layer"})
		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
		formattedParams, err := formatParams(params)
		if err != nil {
			return fmt.Errorf("couldn't create params json: %w", err)
		}

		bts, err := json.Marshal(formattedParams)
		if err != nil {
			return err
		}

		l, err := CreateLayer(bytes.NewReader(bts))
		if err != nil {
			return fmt.Errorf("failed to create layer: %w", err)
		}
		l.MediaType = "application/vnd.ollama.image.params"
		layers = append(layers, l)

		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
		embed.opts = api.DefaultOptions()
		embed.opts.FromMap(formattedParams)
	}

	// generate the embedding layers
	embedLayers, err := embeddingLayers(embed)
	if err != nil {
		return err
	}
	layers = append(layers, embedLayers...)

	digests, err := getLayerDigests(layers)
	if err != nil {
		return err
	}

	var manifestLayers []*Layer
	for _, l := range layers {
		manifestLayers = append(manifestLayers, &l.Layer)
	}

	// Create a layer for the config object
	fn(api.ProgressResponse{Status: "creating config layer"})
	cfg, err := createConfigLayer(config, digests)
	if err != nil {
		return err
	}
	layers = append(layers, cfg)

	if err := SaveLayers(layers, fn, false); err != nil {
		return err
	}

	// Create the manifest
	fn(api.ProgressResponse{Status: "writing manifest"})
	if err := CreateManifest(name, cfg, manifestLayers); err != nil {
		return err
	}

	fn(api.ProgressResponse{Status: "success"})
	return nil
}

473
474
475
476
477
478
479
480
481
482
483
// EmbeddingParams collects what embeddingLayers needs in order to generate
// embed layers while a Modelfile is processed.
type EmbeddingParams struct {
	// model is a path to model weights, or a model name resolved with GetModel.
	model string
	// opts are the options used to load the model for embedding.
	opts api.Options
	// files are paths (glob patterns are expanded) of the files to embed.
	files []string
	// fn receives progress updates.
	fn func(resp api.ProgressResponse)
}

// embeddingLayers loads the associated LLM and generates the embeddings to be
// stored from the input files.
//
// Loading the model:
//   - e.model is used as a weights path when it exists on disk; otherwise it is
//     treated as a model name and resolved through GetModel.
//   - The model is loaded with EmbeddingOnly set.
//
// Processing the files:
//   - Each entry in e.files is expanded with filepath.Glob. A file matched by
//     several patterns is processed once.
//   - Every matched file is read line by line and blank lines are skipped.
//     A line whose embedding fails is logged and skipped.
//   - Each file yields one layer holding its JSON-encoded []vector.Embedding.
//
// With no files, an empty (non-nil) slice is returned and no model is loaded.
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
	layers := []*LayerReader{}
	if len(e.files) == 0 {
		return layers, nil
	}

	if _, err := os.Stat(e.model); err != nil {
		if !os.IsNotExist(err) {
			return nil, fmt.Errorf("failed to get model file to generate embeddings: %w", err)
		}
		// this is a model name rather than the file
		model, err := GetModel(e.model)
		if err != nil {
			return nil, fmt.Errorf("failed to get model to generate embeddings: %w", err)
		}
		e.model = model.ModelPath
	}

	e.opts.EmbeddingOnly = true
	llmModel, err := llm.New(e.model, []string{}, e.opts)
	if err != nil {
		return nil, fmt.Errorf("load model to generate embeddings: %w", err)
	}
	defer func() {
		if llmModel != nil {
			llmModel.Close()
		}
	}()

	addedFiles := make(map[string]bool) // keep track of files that have already been added
	for _, filePattern := range e.files {
		matchingFiles, err := filepath.Glob(filePattern)
		if err != nil {
			return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
		}

		for _, filePath := range matchingFiles {
			if addedFiles[filePath] {
				continue
			}
			addedFiles[filePath] = true

			// TODO: check file type
			f, err := os.Open(filePath)
			if err != nil {
				return nil, fmt.Errorf("could not open embed file: %w", err)
			}
			scanner := bufio.NewScanner(f)
			scanner.Split(bufio.ScanLines)

			var data []string
			for scanner.Scan() {
				data = append(data, scanner.Text())
			}
			// Scan stops on a read error or a line longer than the scanner's
			// buffer; without this check the rest of the file is silently dropped.
			scanErr := scanner.Err()
			f.Close()
			if scanErr != nil {
				return nil, fmt.Errorf("could not read embed file %s: %w", filePath, scanErr)
			}

			// the digest of the file path is reported so that the client knows a new operation is in progress
			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))

			// start from an empty (not nil) slice so a file with no embeddable
			// lines still encodes as [] rather than null
			embeddings := []vector.Embedding{}
			for i, d := range data {
				if strings.TrimSpace(d) == "" {
					continue
				}
				e.fn(api.ProgressResponse{
					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
					Digest:    fileDigest,
					Total:     len(data) - 1,
					Completed: i,
				})
				vec, err := llmModel.Embedding(d)
				if err != nil {
					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
					continue
				}
				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: vec})
			}

			b, err := json.Marshal(embeddings)
			if err != nil {
				return nil, fmt.Errorf("failed to encode embeddings: %w", err)
			}
			r := bytes.NewReader(b)

			digest, size := GetSHA256Digest(r)
			// Reset the position of the reader after calculating the digest
			if _, err := r.Seek(0, io.SeekStart); err != nil {
				return nil, fmt.Errorf("could not reset embed reader: %w", err)
			}

			layers = append(layers, &LayerReader{
				Layer: Layer{
					MediaType: "application/vnd.ollama.image.embed",
					Digest:    digest,
					Size:      size,
				},
				Reader: r,
			})
		}
	}
	return layers, nil
}

Michael Yang's avatar
Michael Yang committed
584
// removeLayerFromLayers drops every layer whose MediaType equals mediaType and
// returns the shortened slice. Filtering happens in place, so the input
// slice's backing array is reused and must not be used afterwards.
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
	kept := layers[:0]
	for _, layer := range layers {
		if layer.MediaType == mediaType {
			continue
		}
		kept = append(kept, layer)
	}
	return kept
}

595
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
596
597
	// Write each of the layers to disk
	for _, layer := range layers {
Patrick Devine's avatar
Patrick Devine committed
598
		fp, err := GetBlobsPath(layer.Digest)
599
600
601
		if err != nil {
			return err
		}
602
603
604

		_, err = os.Stat(fp)
		if os.IsNotExist(err) || force {
605
606
			fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})

607
608
609
610
611
612
613
			out, err := os.Create(fp)
			if err != nil {
				log.Printf("couldn't create %s", fp)
				return err
			}
			defer out.Close()

Michael Yang's avatar
Michael Yang committed
614
			if _, err = io.Copy(out, layer.Reader); err != nil {
615
616
				return err
			}
Michael Yang's avatar
Michael Yang committed
617

618
		} else {
619
			fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
620
621
622
623
624
625
		}
	}

	return nil
}

Michael Yang's avatar
Michael Yang committed
626
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
Patrick Devine's avatar
Patrick Devine committed
627
628
	mp := ParseModelPath(name)

629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
	manifest := ManifestV2{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Config: Layer{
			MediaType: cfg.MediaType,
			Size:      cfg.Size,
			Digest:    cfg.Digest,
		},
		Layers: layers,
	}

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

Patrick Devine's avatar
Patrick Devine committed
645
	fp, err := mp.GetManifestPath(true)
646
647
648
	if err != nil {
		return err
	}
649
	return os.WriteFile(fp, manifestJSON, 0o644)
650
651
}

Michael Yang's avatar
Michael Yang committed
652
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
Patrick Devine's avatar
Patrick Devine committed
653
	fp, err := GetBlobsPath(layer.Digest)
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
	if err != nil {
		return nil, err
	}

	file, err := os.Open(fp)
	if err != nil {
		return nil, fmt.Errorf("could not open blob: %w", err)
	}
	defer file.Close()

	newLayer, err := CreateLayer(file)
	if err != nil {
		return nil, err
	}
	newLayer.MediaType = layer.MediaType
	return newLayer, nil
}

672
673
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
674
675
676
	opts := api.Options{}
	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
Michael Yang's avatar
Michael Yang committed
677

678
	// build map of json struct tags to their types
Michael Yang's avatar
Michael Yang committed
679
680
681
682
683
684
685
686
	jsonOpts := make(map[string]reflect.StructField)
	for _, field := range reflect.VisibleFields(typeOpts) {
		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
		if jsonTag != "" {
			jsonOpts[jsonTag] = field
		}
	}

687
	out := make(map[string]interface{})
Michael Yang's avatar
Michael Yang committed
688
	// iterate params and set values based on json struct tags
689
	for key, vals := range params {
Michael Yang's avatar
Michael Yang committed
690
691
692
693
694
		if opt, ok := jsonOpts[key]; ok {
			field := valueOpts.FieldByName(opt.Name)
			if field.IsValid() && field.CanSet() {
				switch field.Kind() {
				case reflect.Float32:
695
					floatVal, err := strconv.ParseFloat(vals[0], 32)
Michael Yang's avatar
Michael Yang committed
696
					if err != nil {
697
						return nil, fmt.Errorf("invalid float value %s", vals)
Michael Yang's avatar
Michael Yang committed
698
699
					}

700
					out[key] = floatVal
Michael Yang's avatar
Michael Yang committed
701
				case reflect.Int:
702
					intVal, err := strconv.ParseInt(vals[0], 10, 0)
Michael Yang's avatar
Michael Yang committed
703
					if err != nil {
704
						return nil, fmt.Errorf("invalid int value %s", vals)
Michael Yang's avatar
Michael Yang committed
705
706
					}

707
					out[key] = intVal
Michael Yang's avatar
Michael Yang committed
708
				case reflect.Bool:
709
					boolVal, err := strconv.ParseBool(vals[0])
Michael Yang's avatar
Michael Yang committed
710
					if err != nil {
711
						return nil, fmt.Errorf("invalid bool value %s", vals)
Michael Yang's avatar
Michael Yang committed
712
713
					}

714
					out[key] = boolVal
Michael Yang's avatar
Michael Yang committed
715
				case reflect.String:
716
					out[key] = vals[0]
717
				case reflect.Slice:
718
719
					// TODO: only string slices are supported right now
					out[key] = vals
Michael Yang's avatar
Michael Yang committed
720
721
722
723
724
725
726
				default:
					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
				}
			}
		}
	}

727
	return out, nil
728
729
}

Michael Yang's avatar
Michael Yang committed
730
// getLayerDigests returns the digest of every layer, in order. It fails if
// any layer has an empty digest. The result stays nil for an empty input,
// which keeps the config's diff_ids encoding unchanged.
func getLayerDigests(layers []*LayerReader) ([]string, error) {
	var digests []string
	for i := range layers {
		digest := layers[i].Digest
		if digest == "" {
			return nil, errors.New("layer is missing a digest")
		}
		digests = append(digests, digest)
	}
	return digests, nil
}

// CreateLayer creates a Layer object from a given file
Michael Yang's avatar
Michael Yang committed
742
743
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
	digest, size := GetSHA256Digest(f)
744
	f.Seek(0, io.SeekStart)
745

Michael Yang's avatar
Michael Yang committed
746
	layer := &LayerReader{
747
748
749
750
751
		Layer: Layer{
			MediaType: "application/vnd.docker.image.rootfs.diff.tar",
			Digest:    digest,
			Size:      size,
		},
Michael Yang's avatar
Michael Yang committed
752
		Reader: f,
753
754
755
756
757
	}

	return layer, nil
}

Patrick Devine's avatar
Patrick Devine committed
758
759
760
761
762
763
764
765
766
767
768
func CopyModel(src, dest string) error {
	srcPath, err := ParseModelPath(src).GetManifestPath(false)
	if err != nil {
		return err
	}
	destPath, err := ParseModelPath(dest).GetManifestPath(true)
	if err != nil {
		return err
	}

	// copy the file
Michael Yang's avatar
Michael Yang committed
769
	input, err := os.ReadFile(srcPath)
Patrick Devine's avatar
Patrick Devine committed
770
771
772
773
774
	if err != nil {
		fmt.Println("Error reading file:", err)
		return err
	}

Michael Yang's avatar
Michael Yang committed
775
	err = os.WriteFile(destPath, input, 0o644)
Patrick Devine's avatar
Patrick Devine committed
776
777
778
779
780
781
782
783
	if err != nil {
		fmt.Println("Error reading file:", err)
		return err
	}

	return nil
}

784
func DeleteModel(name string) error {
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
	mp := ParseModelPath(name)

	manifest, err := GetManifest(mp)
	if err != nil {
		return err
	}
	deleteMap := make(map[string]bool)
	for _, layer := range manifest.Layers {
		deleteMap[layer.Digest] = true
	}
	deleteMap[manifest.Config.Digest] = true

	fp, err := GetManifestPath()
	if err != nil {
		return err
	}
	err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() {
			path := path[len(fp)+1:]
			slashIndex := strings.LastIndex(path, "/")
			if slashIndex == -1 {
				return nil
			}
			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
			fmp := ParseModelPath(tag)

			// skip the manifest we're trying to delete
			if mp.GetFullTagname() == fmp.GetFullTagname() {
				return nil
			}

			// save (i.e. delete from the deleteMap) any files used in other manifests
			manifest, err := GetManifest(fmp)
			if err != nil {
				log.Printf("skipping file: %s", fp)
				return nil
			}
			for _, layer := range manifest.Layers {
				delete(deleteMap, layer.Digest)
			}
			delete(deleteMap, manifest.Config.Digest)
		}
		return nil
	})
Michael Yang's avatar
Michael Yang committed
832
833
834
	if err != nil {
		return err
	}
835
836
837
838

	// only delete the files which are still in the deleteMap
	for k, v := range deleteMap {
		if v {
839
			fp, err := GetBlobsPath(k)
840
			if err != nil {
841
842
843
844
845
				log.Printf("couldn't get file path for '%s': %v", k, err)
				continue
			}
			if err := os.Remove(fp); err != nil {
				log.Printf("couldn't remove file '%s': %v", fp, err)
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
				continue
			}
		}
	}

	fp, err = mp.GetManifestPath(false)
	if err != nil {
		return err
	}
	err = os.Remove(fp)
	if err != nil {
		log.Printf("couldn't remove manifest file '%s': %v", fp, err)
		return err
	}

	return nil
}

864
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
Patrick Devine's avatar
Patrick Devine committed
865
866
	mp := ParseModelPath(name)

867
868
	fn(api.ProgressResponse{Status: "retrieving manifest"})

Patrick Devine's avatar
Patrick Devine committed
869
	manifest, err := GetManifest(mp)
870
	if err != nil {
871
		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
872
873
874
875
		return err
	}

	var layers []*Layer
Jeffrey Morgan's avatar
Jeffrey Morgan committed
876
	layers = append(layers, manifest.Layers...)
877
878
879
	layers = append(layers, &manifest.Config)

	for _, layer := range layers {
880
		exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
881
882
883
884
885
		if err != nil {
			return err
		}

		if exists {
886
887
888
			fn(api.ProgressResponse{
				Status:    "using existing layer",
				Digest:    layer.Digest,
889
890
				Total:     layer.Size,
				Completed: layer.Size,
891
			})
892
			log.Printf("Layer %s already exists", layer.Digest)
893
894
895
			continue
		}

896
		fn(api.ProgressResponse{
897
898
899
			Status: "starting upload",
			Digest: layer.Digest,
			Total:  layer.Size,
900
		})
901

Michael Yang's avatar
Michael Yang committed
902
		location, err := startUpload(ctx, mp, layer, regOpts)
903
904
905
906
907
		if err != nil {
			log.Printf("couldn't start upload: %v", err)
			return err
		}

Michael Yang's avatar
Michael Yang committed
908
909
910
911
912
913
914
915
916
917
918
919
		if strings.HasPrefix(path.Base(location), "sha256:") {
			layer.Digest = path.Base(location)
			fn(api.ProgressResponse{
				Status:    "using existing layer",
				Digest:    layer.Digest,
				Total:     layer.Size,
				Completed: layer.Size,
			})
			continue
		}

		if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
920
921
922
			log.Printf("error uploading blob: %v", err)
			return err
		}
923
924
	}

925
	fn(api.ProgressResponse{Status: "pushing manifest"})
926
	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
927
928
929
930
931
932
933
934
935
	headers := map[string]string{
		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
	}

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

936
	resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
937
938
939
940
941
942
943
944
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
	if resp.StatusCode != http.StatusCreated {
		body, _ := io.ReadAll(resp.Body)
945
		return fmt.Errorf("on push registry responded with code %d: %v", resp.StatusCode, string(body))
946
947
	}

948
	fn(api.ProgressResponse{Status: "success"})
949
950
951
952

	return nil
}

953
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
Patrick Devine's avatar
Patrick Devine committed
954
	mp := ParseModelPath(name)
955

956
	fn(api.ProgressResponse{Status: "pulling manifest"})
957

958
	manifest, err := pullModelManifest(ctx, mp, regOpts)
959
	if err != nil {
960
		return fmt.Errorf("pull model manifest: %s", err)
961
962
963
	}

	var layers []*Layer
Bruce MacDonald's avatar
Bruce MacDonald committed
964
	layers = append(layers, manifest.Layers...)
965
966
967
	layers = append(layers, &manifest.Config)

	for _, layer := range layers {
968
		if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil {
969
970
971
972
			return err
		}
	}

Michael Yang's avatar
Michael Yang committed
973
974
975
	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
	for _, layer := range layers {
		if err := verifyBlob(layer.Digest); err != nil {
976
977
978
979
980
981
982
983
984
985
986
			if errors.Is(err, errDigestMismatch) {
				// something went wrong, delete the blob
				fp, err := GetBlobsPath(layer.Digest)
				if err != nil {
					return err
				}
				if err := os.Remove(fp); err != nil {
					// log this, but return the original error
					log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
				}
			}
Michael Yang's avatar
Michael Yang committed
987
988
989
990
			return err
		}
	}

991
	fn(api.ProgressResponse{Status: "writing manifest"})
992

993
	manifestJSON, err := json.Marshal(manifest)
994
995
996
997
	if err != nil {
		return err
	}

Patrick Devine's avatar
Patrick Devine committed
998
	fp, err := mp.GetManifestPath(true)
999
1000
1001
1002
	if err != nil {
		return err
	}

Bruce MacDonald's avatar
Bruce MacDonald committed
1003
	err = os.WriteFile(fp, manifestJSON, 0o644)
1004
1005
1006
1007
1008
	if err != nil {
		log.Printf("couldn't write to %s", fp)
		return err
	}

1009
	fn(api.ProgressResponse{Status: "success"})
1010
1011
1012
1013

	return nil
}

1014
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
1015
	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
1016
1017
1018
1019
	headers := map[string]string{
		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
	}

1020
	resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
1021
1022
1023
1024
1025
1026
1027
1028
	if err != nil {
		log.Printf("couldn't get manifest: %v", err)
		return nil, err
	}
	defer resp.Body.Close()

	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
	if resp.StatusCode != http.StatusOK {
1029
		if resp.StatusCode == http.StatusNotFound {
Bruce MacDonald's avatar
Bruce MacDonald committed
1030
			return nil, fmt.Errorf("model not found")
1031
		}
1032
		body, _ := io.ReadAll(resp.Body)
1033
		return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
	}

	var m *ManifestV2
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}

	return m, err
}

1044
1045
1046
1047
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
	config.RootFS = RootFS{
		Type:    "layers",
		DiffIDs: layers,
1048
1049
1050
1051
1052
1053
1054
	}

	configJSON, err := json.Marshal(config)
	if err != nil {
		return nil, err
	}

1055
	digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
1056

Michael Yang's avatar
Michael Yang committed
1057
	layer := &LayerReader{
1058
1059
1060
1061
1062
		Layer: Layer{
			MediaType: "application/vnd.docker.container.image.v1+json",
			Digest:    digest,
			Size:      size,
		},
1063
		Reader: bytes.NewBuffer(configJSON),
1064
1065
1066
1067
1068
	}
	return layer, nil
}

// GetSHA256Digest reads r to the end and returns its SHA-256 hash formatted as
// "sha256:<lowercase hex>", together with the number of bytes read.
//
// A read error is passed to log.Fatal, which terminates the process.
func GetSHA256Digest(r io.Reader) (string, int) {
	hasher := sha256.New()
	written, copyErr := io.Copy(hasher, r)
	if copyErr != nil {
		log.Fatal(copyErr)
	}

	sum := hasher.Sum(nil)
	return fmt.Sprintf("sha256:%x", sum), int(written)
}

Michael Yang's avatar
Michael Yang committed
1079
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
1080
	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
Michael Yang's avatar
Michael Yang committed
1081
1082
1083
	if layer.From != "" {
		url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
	}
1084

1085
	resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
1086
1087
1088
1089
1090
1091
1092
	if err != nil {
		log.Printf("couldn't start upload: %v", err)
		return "", err
	}
	defer resp.Body.Close()

	// Check for success
Michael Yang's avatar
Michael Yang committed
1093
	if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusCreated {
1094
		body, _ := io.ReadAll(resp.Body)
1095
		return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
	}

	// Extract UUID location from header
	location := resp.Header.Get("Location")
	if location == "" {
		return "", fmt.Errorf("location header is missing in response")
	}

	return location, nil
}

// Function to check if a blob already exists in the Docker registry
1108
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
1109
	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
1110

1111
	resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
	if err != nil {
		log.Printf("couldn't check for blob: %v", err)
		return false, err
	}
	defer resp.Body.Close()

	// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
	return resp.StatusCode == http.StatusOK, nil
}

1122
func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
1123
1124
1125
1126
	// TODO allow resumability
	// TODO allow canceling uploads via DELETE
	// TODO allow cross repo blob mount

Patrick Devine's avatar
Patrick Devine committed
1127
	fp, err := GetBlobsPath(layer.Digest)
1128
1129
1130
1131
	if err != nil {
		return err
	}

1132
1133
1134
1135
1136
	f, err := os.Open(fp)
	if err != nil {
		return err
	}

1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
	totalUploaded := 0

	r, w := io.Pipe()
	defer r.Close()

	go func() {
		defer w.Close()
		for {
			n, err := io.CopyN(w, f, 1024*1024)
			if err != nil && !errors.Is(err, io.EOF) {
				fn(api.ProgressResponse{
					Status:    fmt.Sprintf("error copying pipe: %v", err),
					Digest:    layer.Digest,
					Total:     layer.Size,
					Completed: totalUploaded,
				})
				return
			}
1155

1156
			totalUploaded += int(n)
1157
1158

			fn(api.ProgressResponse{
1159
				Status:    fmt.Sprintf("uploading %s", layer.Digest),
1160
				Digest:    layer.Digest,
1161
1162
				Total:     layer.Size,
				Completed: totalUploaded,
1163
			})
1164
1165
1166
1167

			if totalUploaded >= layer.Size {
				return
			}
1168
		}
1169
	}()
1170

1171
	url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
1172

1173
1174
1175
1176
	headers := make(map[string]string)
	headers["Content-Type"] = "application/octet-stream"
	headers["Content-Range"] = fmt.Sprintf("0-%d", layer.Size-1)
	headers["Content-Length"] = strconv.Itoa(int(layer.Size))
1177

1178
	// finish the upload
1179
	resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts)
1180
1181
1182
1183
1184
1185
1186
1187
1188
	if err != nil {
		log.Printf("couldn't finish upload: %v", err)
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
1189
	}
1190
1191
1192
	return nil
}

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
	retryCtx := ctx.Value("retries")
	var retries int
	var ok bool
	if retries, ok = retryCtx.(int); ok {
		if retries > MaxRetries {
			return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?")
		}
	}

1203
1204
1205
1206
1207
1208
1209
1210
	if !strings.HasPrefix(url, "http") {
		if regOpts.Insecure {
			url = "http://" + url
		} else {
			url = "https://" + url
		}
	}

Patrick Devine's avatar
Patrick Devine committed
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
	// make a copy of the body in case we need to try the call to makeRequest again
	var buf bytes.Buffer
	if body != nil {
		_, err := io.Copy(&buf, body)
		if err != nil {
			return nil, err
		}
	}

	bodyCopy := bytes.NewReader(buf.Bytes())

	req, err := http.NewRequest(method, url, bodyCopy)
1223
1224
1225
1226
	if err != nil {
		return nil, err
	}

Patrick Devine's avatar
Patrick Devine committed
1227
1228
1229
1230
	if regOpts.Token != "" {
		req.Header.Set("Authorization", "Bearer "+regOpts.Token)
	} else if regOpts.Username != "" && regOpts.Password != "" {
		req.SetBasicAuth(regOpts.Username, regOpts.Password)
1231
1232
	}

Patrick Devine's avatar
Patrick Devine committed
1233
1234
	for k, v := range headers {
		req.Header.Set(k, v)
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
	}

	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 10 {
				return fmt.Errorf("too many redirects")
			}
			log.Printf("redirected to: %s\n", req.URL)
			return nil
		},
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

Patrick Devine's avatar
Patrick Devine committed
1251
1252
1253
1254
	// if the request is unauthenticated, try to authenticate and make the request again
	if resp.StatusCode == http.StatusUnauthorized {
		auth := resp.Header.Get("Www-Authenticate")
		authRedir := ParseAuthRedirectString(string(auth))
1255
		token, err := getAuthToken(ctx, authRedir, regOpts)
Patrick Devine's avatar
Patrick Devine committed
1256
1257
1258
1259
1260
		if err != nil {
			return nil, err
		}
		regOpts.Token = token
		bodyCopy = bytes.NewReader(buf.Bytes())
1261
1262
		ctx = context.WithValue(ctx, "retries", retries+1)
		return makeRequest(ctx, method, url, headers, bodyCopy, regOpts)
Patrick Devine's avatar
Patrick Devine committed
1263
1264
	}

1265
1266
	return resp, nil
}
Michael Yang's avatar
Michael Yang committed
1267

Patrick Devine's avatar
Patrick Devine committed
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
// getValue returns the value of the auth-param key in a comma-separated
// challenge such as `realm="https://x/token",service="y"`, or "" when key is
// absent.
//
// A match only counts when key starts a parameter: at the start of header,
// or right after ',' or ' '. This keeps a key embedded in another value (for
// example "service=" inside a realm URL's query string) from being mistaken
// for the parameter itself.
//
// Quoted values end at a '"' that is followed by ',' or by the end of header,
// so commas inside quotes are kept. Unquoted values end at the next ','.
// Malformed input such as `realm=` or `realm="` yields "" instead of
// panicking.
func getValue(header, key string) string {
	prefix := key + "="
	offset := 0
	for {
		idx := strings.Index(header[offset:], prefix)
		if idx == -1 {
			return ""
		}
		idx += offset

		if idx > 0 && header[idx-1] != ',' && header[idx-1] != ' ' {
			// Not at a parameter boundary; keep searching after this match.
			offset = idx + len(prefix)
			continue
		}

		value := header[idx+len(prefix):]
		if !strings.HasPrefix(value, `"`) {
			if end := strings.IndexByte(value, ','); end != -1 {
				value = value[:end]
			}
			return strings.TrimSpace(value)
		}

		value = value[1:]
		for i := 0; i < len(value); i++ {
			if value[i] == '"' && (i+1 == len(value) || value[i+1] == ',') {
				return value[:i]
			}
		}
		// No closing quote: return everything after the opening quote.
		return value
	}
}

// ParseAuthRedirectString extracts the realm, service and scope parameters
// from a Www-Authenticate header value, after removing an optional leading
// "Bearer " scheme prefix. Parameters that are absent are left empty.
func ParseAuthRedirectString(authStr string) AuthRedirect {
	params := strings.TrimPrefix(authStr, "Bearer ")

	var redirect AuthRedirect
	redirect.Realm = getValue(params, "realm")
	redirect.Service = getValue(params, "service")
	redirect.Scope = getValue(params, "scope")
	return redirect
}

1301
1302
// errDigestMismatch is wrapped by verifyBlob when a blob's computed SHA-256
// digest differs from the expected one. Test for it with errors.Is.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

Michael Yang's avatar
Michael Yang committed
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
func verifyBlob(digest string) error {
	fp, err := GetBlobsPath(digest)
	if err != nil {
		return err
	}

	f, err := os.Open(fp)
	if err != nil {
		return err
	}
	defer f.Close()

	fileDigest, _ := GetSHA256Digest(f)
	if digest != fileDigest {
1317
		return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
Michael Yang's avatar
Michael Yang committed
1318
1319
1320
1321
	}

	return nil
}