images.go 29.8 KB
Newer Older
1
2
3
package server

import (
4
	"bufio"
5
	"bytes"
6
	"context"
7
8
9
10
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
11
	"html/template"
12
13
14
15
16
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
Michael Yang's avatar
Michael Yang committed
17
	"reflect"
18
19
20
21
	"strconv"
	"strings"

	"github.com/jmorganca/ollama/api"
22
	"github.com/jmorganca/ollama/llm"
23
	"github.com/jmorganca/ollama/parser"
24
	"github.com/jmorganca/ollama/vector"
25
26
)

27
28
29
30
31
32
// RegistryOptions carries the connection settings handed to the registry
// request helpers (makeRequest, checkBlobExistence, ...) when pushing or
// pulling models.
type RegistryOptions struct {
	// Insecure is forwarded to the request helpers; its exact effect is
	// defined outside this file (presumably relaxing TLS/HTTP checks — confirm).
	Insecure bool

	// Username and Password are optional registry credentials.
	Username, Password string
}

33
type Model struct {
34
35
36
37
38
39
40
41
	Name         string `json:"name"`
	ModelPath    string
	AdapterPaths []string
	Template     string
	System       string
	Digest       string
	Options      map[string]interface{}
	Embeddings   []vector.Embedding
42
43
}

Bruce MacDonald's avatar
Bruce MacDonald committed
44
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
45
46
47
48
49
50
	t := m.Template
	if request.Template != "" {
		t = request.Template
	}

	tmpl, err := template.New("").Parse(t)
51
52
53
54
55
	if err != nil {
		return "", err
	}

	var vars struct {
Michael Yang's avatar
Michael Yang committed
56
		First  bool
57
58
		System string
		Prompt string
59
		Embed  string
60
61
62

		// deprecated: versions <= 0.0.7 used this to omit the system prompt
		Context []int
63
64
	}

Michael Yang's avatar
Michael Yang committed
65
	vars.First = len(request.Context) == 0
66
67
	vars.System = m.System
	vars.Prompt = request.Prompt
Michael Yang's avatar
Michael Yang committed
68
	vars.Context = request.Context
Bruce MacDonald's avatar
Bruce MacDonald committed
69
	vars.Embed = embedding
70

71
72
73
74
	if request.System != "" {
		vars.System = request.System
	}

75
76
77
78
79
80
81
82
	var sb strings.Builder
	if err := tmpl.Execute(&sb, vars); err != nil {
		return "", err
	}

	return sb.String(), nil
}

83
84
85
86
87
88
89
90
91
92
93
94
95
// ManifestV2 is a schema-version-2 image manifest (Docker distribution
// format) describing a model as one config blob plus an ordered list of
// layer blobs.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        Layer    `json:"config"`
	Layers        []*Layer `json:"layers"`
}

// Layer is a blob descriptor: the blob's media type, its content digest and
// its size.
type Layer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int    `json:"size"`
}

Michael Yang's avatar
Michael Yang committed
96
type LayerReader struct {
97
	Layer
Michael Yang's avatar
Michael Yang committed
98
	io.Reader
99
100
101
}

type ConfigV2 struct {
102
103
104
105
106
107
	ModelFamily llm.ModelFamily `json:"model_family"`
	ModelType   llm.ModelType   `json:"model_type"`
	FileType    llm.FileType    `json:"file_type"`
	RootFS      RootFS          `json:"rootfs"`

	// required by spec
108
109
110
111
112
113
114
115
116
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
}

// RootFS is the rootfs section of the config blob. createConfigLayer sets
// Type to "layers" and DiffIDs to the digests of the model's layers.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}

Patrick Devine's avatar
Patrick Devine committed
117
118
119
120
121
122
123
124
125
// GetTotalSize reports the combined size of the config blob and every layer
// listed in the manifest.
func (m *ManifestV2) GetTotalSize() int {
	sum := m.Config.Size
	for _, l := range m.Layers {
		sum += l.Size
	}
	return sum
}

Patrick Devine's avatar
Patrick Devine committed
126
127
// GetManifest reads and decodes the locally stored manifest for mp.
//
// When the manifest file cannot be stat'ed (for example it does not exist),
// the os.Stat error is returned as-is, unwrapped, so callers can still test
// it with os.IsNotExist.
func GetManifest(mp ModelPath) (*ManifestV2, error) {
	fp, err := mp.GetManifestPath(false)
	if err != nil {
		return nil, err
	}

	if _, err = os.Stat(fp); err != nil {
		return nil, err
	}

	bts, err := os.ReadFile(fp)
	if err != nil {
		return nil, fmt.Errorf("couldn't open file '%s': %w", fp, err)
	}

	// Decode into a value rather than into a nil pointer: a file containing
	// JSON `null` would otherwise produce (nil, nil) and callers would
	// dereference a nil manifest.
	var manifest ManifestV2
	if err := json.Unmarshal(bts, &manifest); err != nil {
		return nil, err
	}

	return &manifest, nil
}

func GetModel(name string) (*Model, error) {
Patrick Devine's avatar
Patrick Devine committed
151
152
153
	mp := ParseModelPath(name)

	manifest, err := GetManifest(mp)
154
155
156
157
158
	if err != nil {
		return nil, err
	}

	model := &Model{
159
		Name:   mp.GetFullTagname(),
Jeffrey Morgan's avatar
Jeffrey Morgan committed
160
		Digest: manifest.Config.Digest,
161
162
163
	}

	for _, layer := range manifest.Layers {
Patrick Devine's avatar
Patrick Devine committed
164
		filename, err := GetBlobsPath(layer.Digest)
165
166
167
168
		if err != nil {
			return nil, err
		}

169
170
171
		switch layer.MediaType {
		case "application/vnd.ollama.image.model":
			model.ModelPath = filename
172
173
174
175
176
177
178
179
180
181
		case "application/vnd.ollama.image.embed":
			file, err := os.Open(filename)
			if err != nil {
				return nil, fmt.Errorf("failed to open file: %s", filename)
			}
			defer file.Close()

			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
				return nil, err
			}
182
183
		case "application/vnd.ollama.image.adapter":
			model.AdapterPaths = append(model.AdapterPaths, filename)
184
185
186
187
188
189
190
191
192
		case "application/vnd.ollama.image.template":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.Template = string(bts)
		case "application/vnd.ollama.image.system":
			bts, err := os.ReadFile(filename)
193
194
195
			if err != nil {
				return nil, err
			}
196
197

			model.System = string(bts)
198
199
200
201
202
203
204
		case "application/vnd.ollama.image.prompt":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.Template = string(bts)
205
		case "application/vnd.ollama.image.params":
Michael Yang's avatar
Michael Yang committed
206
207
208
209
210
			params, err := os.Open(filename)
			if err != nil {
				return nil, err
			}
			defer params.Close()
211

212
			// parse model options parameters into a map so that we can see which fields have been specified explicitly
213
			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
214
215
				return nil, err
			}
216
217
218
219
220
221
		}
	}

	return model, nil
}

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
// filenameWithPath resolves f, a file path written inside the Modelfile
// located at path:
//   - a leading "~/" is replaced by the current user's home directory;
//   - a path that is still relative is taken relative to the Modelfile's
//     directory;
//   - an absolute path is returned as-is (after ~ expansion).
//
// The only error is a failure to determine the home directory; it wraps the
// os.UserHomeDir error so callers can inspect it.
func filenameWithPath(path, f string) (string, error) {
	if strings.HasPrefix(f, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("failed to open file: %w", err)
		}

		// filepath.Join cleans the result, collapsing duplicate separators
		f = filepath.Join(home, strings.TrimPrefix(f, "~/"))
	}

	if !filepath.IsAbs(f) {
		f = filepath.Join(filepath.Dir(path), f)
	}

	return f, nil
}

242
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
243
244
	mf, err := os.Open(path)
	if err != nil {
245
		fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
246
		return fmt.Errorf("failed to open file: %w", err)
247
	}
248
	defer mf.Close()
249

250
	fn(api.ProgressResponse{Status: "parsing modelfile"})
251
252
253
254
255
	commands, err := parser.Parse(mf)
	if err != nil {
		return err
	}

256
257
258
259
260
	config := ConfigV2{
		Architecture: "amd64",
		OS:           "linux",
	}

Michael Yang's avatar
Michael Yang committed
261
	var layers []*LayerReader
262
	params := make(map[string][]string)
263
	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
264
	for _, c := range commands {
265
		log.Printf("[%s] - %s\n", c.Name, c.Args)
266
267
		switch c.Name {
		case "model":
268
			fn(api.ProgressResponse{Status: "looking for model"})
269
			embed.model = c.Args
270
			mf, err := GetManifest(ParseModelPath(c.Args))
271
			if err != nil {
272
273
274
				modelFile, err := filenameWithPath(path, c.Args)
				if err != nil {
					return err
275
				}
276
				if _, err := os.Stat(modelFile); err != nil {
277
278
279
					// the model file does not exist, try pulling it
					if errors.Is(err, os.ErrNotExist) {
						fn(api.ProgressResponse{Status: "pulling model file"})
280
						if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
281
282
							return err
						}
283
						mf, err = GetManifest(ParseModelPath(c.Args))
284
285
286
287
288
289
290
291
292
						if err != nil {
							return fmt.Errorf("failed to open file after pull: %v", err)
						}
					} else {
						return err
					}
				} else {
					// create a model from this specified file
					fn(api.ProgressResponse{Status: "creating model layer"})
293
					file, err := os.Open(modelFile)
294
295
296
297
298
					if err != nil {
						return fmt.Errorf("failed to open file: %v", err)
					}
					defer file.Close()

299
300
301
302
303
304
305
306
307
308
309
310
					ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
					if err != nil {
						return err
					}

					config.ModelFamily = ggml.ModelFamily
					config.ModelType = ggml.ModelType
					config.FileType = ggml.FileType

					// reset the file
					file.Seek(0, io.SeekStart)

311
312
313
314
315
316
					l, err := CreateLayer(file)
					if err != nil {
						return fmt.Errorf("failed to create layer: %v", err)
					}
					l.MediaType = "application/vnd.ollama.image.model"
					layers = append(layers, l)
317
				}
318
			}
319

320
			if mf != nil {
321
322
323
324
325
326
327
328
329
				log.Printf("manifest = %#v", mf)
				for _, l := range mf.Layers {
					newLayer, err := GetLayerWithBufferFromLayer(l)
					if err != nil {
						return err
					}
					layers = append(layers, newLayer)
				}
			}
330
331
		case "embed":
			embedFilePath, err := filenameWithPath(path, c.Args)
Michael Yang's avatar
Michael Yang committed
332
333
334
			if err != nil {
				return err
			}
335
			embed.files = append(embed.files, embedFilePath)
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
		case "adapter":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})

			fp := c.Args
			if strings.HasPrefix(fp, "~/") {
				parts := strings.Split(fp, "/")
				home, err := os.UserHomeDir()
				if err != nil {
					return fmt.Errorf("failed to open file: %v", err)
				}

				fp = filepath.Join(home, filepath.Join(parts[1:]...))
			}

			// If filePath is not an absolute path, make it relative to the modelfile path
			if !filepath.IsAbs(fp) {
				fp = filepath.Join(filepath.Dir(path), fp)
			}

			// create a model from this specified file
			fn(api.ProgressResponse{Status: "creating model layer"})

			file, err := os.Open(fp)
			if err != nil {
				return fmt.Errorf("failed to open file: %v", err)
			}
			defer file.Close()

			l, err := CreateLayer(file)
			if err != nil {
				return fmt.Errorf("failed to create layer: %v", err)
			}
			l.MediaType = "application/vnd.ollama.image.adapter"
			layers = append(layers, l)
Bruce MacDonald's avatar
Bruce MacDonald committed
370
371
372
373
374
375
376
377
378
379
380
381
		case "license":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

			layer, err := CreateLayer(strings.NewReader(c.Args))
			if err != nil {
				return err
			}

			layer.MediaType = mediaType
			layers = append(layers, layer)
		case "template", "system", "prompt":
382
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
383
			// remove the layer if one exists
384
385
			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
			layers = removeLayerFromLayers(layers, mediaType)
386

387
			layer, err := CreateLayer(strings.NewReader(c.Args))
388
			if err != nil {
389
				return err
390
			}
391
392
393

			layer.MediaType = mediaType
			layers = append(layers, layer)
394
		default:
395
396
			// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
			params[c.Name] = append(params[c.Name], c.Args)
397
398
399
400
		}
	}

	// Create a single layer for the parameters
Michael Yang's avatar
Michael Yang committed
401
	if len(params) > 0 {
402
		fn(api.ProgressResponse{Status: "creating parameter layer"})
403
		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
404
		formattedParams, err := formatParams(params)
405
406
407
		if err != nil {
			return fmt.Errorf("couldn't create params json: %v", err)
		}
408
409
410
411
412
413
414

		bts, err := json.Marshal(formattedParams)
		if err != nil {
			return err
		}

		l, err := CreateLayer(bytes.NewReader(bts))
415
416
417
418
419
		if err != nil {
			return fmt.Errorf("failed to create layer: %v", err)
		}
		l.MediaType = "application/vnd.ollama.image.params"
		layers = append(layers, l)
420
421
422
423

		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
		embed.opts = api.DefaultOptions()
		embed.opts.FromMap(formattedParams)
424
425
	}

426
427
428
429
430
431
432
	// generate the embedding layers
	embeddingLayers, err := embeddingLayers(embed)
	if err != nil {
		return err
	}
	layers = append(layers, embeddingLayers...)

433
434
435
436
437
438
439
440
441
442
443
	digests, err := getLayerDigests(layers)
	if err != nil {
		return err
	}

	var manifestLayers []*Layer
	for _, l := range layers {
		manifestLayers = append(manifestLayers, &l.Layer)
	}

	// Create a layer for the config object
444
	fn(api.ProgressResponse{Status: "creating config layer"})
445
	cfg, err := createConfigLayer(config, digests)
446
447
448
449
450
451
452
453
454
455
456
	if err != nil {
		return err
	}
	layers = append(layers, cfg)

	err = SaveLayers(layers, fn, false)
	if err != nil {
		return err
	}

	// Create the manifest
457
	fn(api.ProgressResponse{Status: "writing manifest"})
458
459
460
461
462
	err = CreateManifest(name, cfg, manifestLayers)
	if err != nil {
		return err
	}

463
	fn(api.ProgressResponse{Status: "success"})
464
465
466
	return nil
}

467
468
469
470
471
472
473
474
475
476
477
// EmbeddingParams gathers what embeddingLayers needs while a model is being
// created. model is a model name, or a path to a model file, used to compute
// the embeddings. opts are the runner options. files are glob patterns of the
// files to embed. fn receives progress updates.
type EmbeddingParams struct {
	model string
	opts  api.Options
	files []string // glob patterns of files to embed
	fn    func(resp api.ProgressResponse)
}

// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
	layers := []*LayerReader{}
	if len(e.files) > 0 {
478
479
480
481
482
483
484
485
486
487
488
		if _, err := os.Stat(e.model); err != nil {
			if os.IsNotExist(err) {
				// this is a model name rather than the file
				model, err := GetModel(e.model)
				if err != nil {
					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
				}
				e.model = model.ModelPath
			} else {
				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
			}
489
490
491
		}

		e.opts.EmbeddingOnly = true
492
		llmModel, err := llm.New(e.model, []string{}, e.opts)
493
494
495
		if err != nil {
			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
		}
Bruce MacDonald's avatar
Bruce MacDonald committed
496
		defer func() {
497
498
			if llmModel != nil {
				llmModel.Close()
Bruce MacDonald's avatar
Bruce MacDonald committed
499
500
			}
		}()
501

Bruce MacDonald's avatar
Bruce MacDonald committed
502
503
504
		addedFiles := make(map[string]bool) // keep track of files that have already been added
		for _, filePattern := range e.files {
			matchingFiles, err := filepath.Glob(filePattern)
505
			if err != nil {
Bruce MacDonald's avatar
Bruce MacDonald committed
506
				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
507
508
			}

Bruce MacDonald's avatar
Bruce MacDonald committed
509
510
			for _, filePath := range matchingFiles {
				if addedFiles[filePath] {
511
512
					continue
				}
Bruce MacDonald's avatar
Bruce MacDonald committed
513
514
515
				addedFiles[filePath] = true
				// TODO: check file type
				f, err := os.Open(filePath)
516
				if err != nil {
Bruce MacDonald's avatar
Bruce MacDonald committed
517
					return nil, fmt.Errorf("could not open embed file: %w", err)
518
				}
Bruce MacDonald's avatar
Bruce MacDonald committed
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
				scanner := bufio.NewScanner(f)
				scanner.Split(bufio.ScanLines)

				data := []string{}
				for scanner.Scan() {
					data = append(data, scanner.Text())
				}
				f.Close()

				// the digest of the file is set here so that the client knows a new operation is in progress
				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))

				embeddings := []vector.Embedding{}
				for i, d := range data {
					if strings.TrimSpace(d) == "" {
						continue
					}
					e.fn(api.ProgressResponse{
						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
						Digest:    fileDigest,
						Total:     len(data) - 1,
						Completed: i,
					})
542
					embed, err := llmModel.Embedding(d)
Bruce MacDonald's avatar
Bruce MacDonald committed
543
					if err != nil {
544
545
						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
						continue
Bruce MacDonald's avatar
Bruce MacDonald committed
546
547
					}
					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
548
549
				}

Bruce MacDonald's avatar
Bruce MacDonald committed
550
551
552
553
554
				b, err := json.Marshal(embeddings)
				if err != nil {
					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
				}
				r := bytes.NewReader(b)
555

Bruce MacDonald's avatar
Bruce MacDonald committed
556
557
558
559
560
				digest, size := GetSHA256Digest(r)
				// Reset the position of the reader after calculating the digest
				if _, err := r.Seek(0, io.SeekStart); err != nil {
					return nil, fmt.Errorf("could not reset embed reader: %w", err)
				}
561

Bruce MacDonald's avatar
Bruce MacDonald committed
562
563
564
565
566
567
568
569
				layer := &LayerReader{
					Layer: Layer{
						MediaType: "application/vnd.ollama.image.embed",
						Digest:    digest,
						Size:      size,
					},
					Reader: r,
				}
570

Bruce MacDonald's avatar
Bruce MacDonald committed
571
572
				layers = append(layers, layer)
			}
573
574
575
576
577
		}
	}
	return layers, nil
}

Michael Yang's avatar
Michael Yang committed
578
// removeLayerFromLayers filters layers in place, dropping every layer whose
// media type equals mediaType. The result reuses layers' backing array, so
// the input slice must not be used afterwards.
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
	kept := layers[:0]
	for _, layer := range layers {
		if layer.MediaType == mediaType {
			continue
		}
		kept = append(kept, layer)
	}
	return kept
}

589
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
590
591
	// Write each of the layers to disk
	for _, layer := range layers {
Patrick Devine's avatar
Patrick Devine committed
592
		fp, err := GetBlobsPath(layer.Digest)
593
594
595
		if err != nil {
			return err
		}
596
597
598

		_, err = os.Stat(fp)
		if os.IsNotExist(err) || force {
599
600
			fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})

601
602
603
604
605
606
607
			out, err := os.Create(fp)
			if err != nil {
				log.Printf("couldn't create %s", fp)
				return err
			}
			defer out.Close()

Michael Yang's avatar
Michael Yang committed
608
			if _, err = io.Copy(out, layer.Reader); err != nil {
609
610
				return err
			}
Michael Yang's avatar
Michael Yang committed
611

612
		} else {
613
			fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
614
615
616
617
618
619
		}
	}

	return nil
}

Michael Yang's avatar
Michael Yang committed
620
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
Patrick Devine's avatar
Patrick Devine committed
621
622
	mp := ParseModelPath(name)

623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
	manifest := ManifestV2{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Config: Layer{
			MediaType: cfg.MediaType,
			Size:      cfg.Size,
			Digest:    cfg.Digest,
		},
		Layers: layers,
	}

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

Patrick Devine's avatar
Patrick Devine committed
639
	fp, err := mp.GetManifestPath(true)
640
641
642
	if err != nil {
		return err
	}
643
	return os.WriteFile(fp, manifestJSON, 0o644)
644
645
}

Michael Yang's avatar
Michael Yang committed
646
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
Patrick Devine's avatar
Patrick Devine committed
647
	fp, err := GetBlobsPath(layer.Digest)
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
	if err != nil {
		return nil, err
	}

	file, err := os.Open(fp)
	if err != nil {
		return nil, fmt.Errorf("could not open blob: %w", err)
	}
	defer file.Close()

	newLayer, err := CreateLayer(file)
	if err != nil {
		return nil, err
	}
	newLayer.MediaType = layer.MediaType
	return newLayer, nil
}

666
667
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
668
669
670
	opts := api.Options{}
	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
Michael Yang's avatar
Michael Yang committed
671

672
	// build map of json struct tags to their types
Michael Yang's avatar
Michael Yang committed
673
674
675
676
677
678
679
680
	jsonOpts := make(map[string]reflect.StructField)
	for _, field := range reflect.VisibleFields(typeOpts) {
		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
		if jsonTag != "" {
			jsonOpts[jsonTag] = field
		}
	}

681
	out := make(map[string]interface{})
Michael Yang's avatar
Michael Yang committed
682
	// iterate params and set values based on json struct tags
683
	for key, vals := range params {
Michael Yang's avatar
Michael Yang committed
684
685
686
687
688
		if opt, ok := jsonOpts[key]; ok {
			field := valueOpts.FieldByName(opt.Name)
			if field.IsValid() && field.CanSet() {
				switch field.Kind() {
				case reflect.Float32:
689
					floatVal, err := strconv.ParseFloat(vals[0], 32)
Michael Yang's avatar
Michael Yang committed
690
					if err != nil {
691
						return nil, fmt.Errorf("invalid float value %s", vals)
Michael Yang's avatar
Michael Yang committed
692
693
					}

694
					out[key] = floatVal
Michael Yang's avatar
Michael Yang committed
695
				case reflect.Int:
696
					intVal, err := strconv.ParseInt(vals[0], 10, 0)
Michael Yang's avatar
Michael Yang committed
697
					if err != nil {
698
						return nil, fmt.Errorf("invalid int value %s", vals)
Michael Yang's avatar
Michael Yang committed
699
700
					}

701
					out[key] = intVal
Michael Yang's avatar
Michael Yang committed
702
				case reflect.Bool:
703
					boolVal, err := strconv.ParseBool(vals[0])
Michael Yang's avatar
Michael Yang committed
704
					if err != nil {
705
						return nil, fmt.Errorf("invalid bool value %s", vals)
Michael Yang's avatar
Michael Yang committed
706
707
					}

708
					out[key] = boolVal
Michael Yang's avatar
Michael Yang committed
709
				case reflect.String:
710
					out[key] = vals[0]
711
				case reflect.Slice:
712
713
					// TODO: only string slices are supported right now
					out[key] = vals
Michael Yang's avatar
Michael Yang committed
714
715
716
717
718
719
720
				default:
					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
				}
			}
		}
	}

721
	return out, nil
722
723
}

Michael Yang's avatar
Michael Yang committed
724
// getLayerDigests collects the digest of every layer, in order. It fails if
// any layer has no digest, and returns nil for an empty input.
func getLayerDigests(layers []*LayerReader) ([]string, error) {
	var digests []string
	for i := range layers {
		d := layers[i].Digest
		if d == "" {
			return nil, fmt.Errorf("layer is missing a digest")
		}
		digests = append(digests, d)
	}
	return digests, nil
}

// CreateLayer creates a Layer object from a given file
Michael Yang's avatar
Michael Yang committed
736
737
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
	digest, size := GetSHA256Digest(f)
738
	f.Seek(0, io.SeekStart)
739

Michael Yang's avatar
Michael Yang committed
740
	layer := &LayerReader{
741
742
743
744
745
		Layer: Layer{
			MediaType: "application/vnd.docker.image.rootfs.diff.tar",
			Digest:    digest,
			Size:      size,
		},
Michael Yang's avatar
Michael Yang committed
746
		Reader: f,
747
748
749
750
751
	}

	return layer, nil
}

Patrick Devine's avatar
Patrick Devine committed
752
753
754
755
756
757
758
759
760
761
762
func CopyModel(src, dest string) error {
	srcPath, err := ParseModelPath(src).GetManifestPath(false)
	if err != nil {
		return err
	}
	destPath, err := ParseModelPath(dest).GetManifestPath(true)
	if err != nil {
		return err
	}

	// copy the file
Michael Yang's avatar
Michael Yang committed
763
	input, err := os.ReadFile(srcPath)
Patrick Devine's avatar
Patrick Devine committed
764
765
766
767
768
	if err != nil {
		fmt.Println("Error reading file:", err)
		return err
	}

Michael Yang's avatar
Michael Yang committed
769
	err = os.WriteFile(destPath, input, 0o644)
Patrick Devine's avatar
Patrick Devine committed
770
771
772
773
774
775
776
777
	if err != nil {
		fmt.Println("Error reading file:", err)
		return err
	}

	return nil
}

778
func DeleteModel(name string) error {
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
	mp := ParseModelPath(name)

	manifest, err := GetManifest(mp)
	if err != nil {
		return err
	}
	deleteMap := make(map[string]bool)
	for _, layer := range manifest.Layers {
		deleteMap[layer.Digest] = true
	}
	deleteMap[manifest.Config.Digest] = true

	fp, err := GetManifestPath()
	if err != nil {
		return err
	}
	err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() {
			path := path[len(fp)+1:]
			slashIndex := strings.LastIndex(path, "/")
			if slashIndex == -1 {
				return nil
			}
			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
			fmp := ParseModelPath(tag)

			// skip the manifest we're trying to delete
			if mp.GetFullTagname() == fmp.GetFullTagname() {
				return nil
			}

			// save (i.e. delete from the deleteMap) any files used in other manifests
			manifest, err := GetManifest(fmp)
			if err != nil {
				log.Printf("skipping file: %s", fp)
				return nil
			}
			for _, layer := range manifest.Layers {
				delete(deleteMap, layer.Digest)
			}
			delete(deleteMap, manifest.Config.Digest)
		}
		return nil
	})
Michael Yang's avatar
Michael Yang committed
826
827
828
	if err != nil {
		return err
	}
829
830
831
832

	// only delete the files which are still in the deleteMap
	for k, v := range deleteMap {
		if v {
833
			fp, err := GetBlobsPath(k)
834
			if err != nil {
835
836
837
838
839
				log.Printf("couldn't get file path for '%s': %v", k, err)
				continue
			}
			if err := os.Remove(fp); err != nil {
				log.Printf("couldn't remove file '%s': %v", fp, err)
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
				continue
			}
		}
	}

	fp, err = mp.GetManifestPath(false)
	if err != nil {
		return err
	}
	err = os.Remove(fp)
	if err != nil {
		log.Printf("couldn't remove manifest file '%s': %v", fp, err)
		return err
	}

	return nil
}

858
func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
Patrick Devine's avatar
Patrick Devine committed
859
860
	mp := ParseModelPath(name)

861
862
	fn(api.ProgressResponse{Status: "retrieving manifest"})

Patrick Devine's avatar
Patrick Devine committed
863
	manifest, err := GetManifest(mp)
864
	if err != nil {
865
		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
866
867
868
869
		return err
	}

	var layers []*Layer
Jeffrey Morgan's avatar
Jeffrey Morgan committed
870
	layers = append(layers, manifest.Layers...)
871
872
873
	layers = append(layers, &manifest.Config)

	for _, layer := range layers {
874
		exists, err := checkBlobExistence(mp, layer.Digest, regOpts)
875
876
877
878
879
		if err != nil {
			return err
		}

		if exists {
880
881
882
			fn(api.ProgressResponse{
				Status:    "using existing layer",
				Digest:    layer.Digest,
883
884
				Total:     layer.Size,
				Completed: layer.Size,
885
			})
886
			log.Printf("Layer %s already exists", layer.Digest)
887
888
889
			continue
		}

890
		fn(api.ProgressResponse{
891
892
893
			Status: "starting upload",
			Digest: layer.Digest,
			Total:  layer.Size,
894
		})
895

896
		location, err := startUpload(mp, regOpts)
897
898
899
900
901
		if err != nil {
			log.Printf("couldn't start upload: %v", err)
			return err
		}

902
		err = uploadBlobChunked(mp, location, layer, regOpts, fn)
903
904
905
906
		if err != nil {
			log.Printf("error uploading blob: %v", err)
			return err
		}
907
908
	}

909
	fn(api.ProgressResponse{Status: "pushing manifest"})
910
	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
911
912
913
914
915
916
917
918
919
	headers := map[string]string{
		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
	}

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

920
	resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
921
922
923
924
925
926
927
928
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
	if resp.StatusCode != http.StatusCreated {
		body, _ := io.ReadAll(resp.Body)
929
		return fmt.Errorf("on push registry responded with code %d: %v", resp.StatusCode, string(body))
930
931
	}

932
	fn(api.ProgressResponse{Status: "success"})
933
934
935
936

	return nil
}

937
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
Patrick Devine's avatar
Patrick Devine committed
938
	mp := ParseModelPath(name)
939

940
	fn(api.ProgressResponse{Status: "pulling manifest"})
941

942
	manifest, err := pullModelManifest(mp, regOpts)
943
	if err != nil {
944
		return fmt.Errorf("pull model manifest: %s", err)
945
946
947
	}

	var layers []*Layer
Bruce MacDonald's avatar
Bruce MacDonald committed
948
	layers = append(layers, manifest.Layers...)
949
950
951
	layers = append(layers, &manifest.Config)

	for _, layer := range layers {
952
		if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil {
953
954
955
956
			return err
		}
	}

Michael Yang's avatar
Michael Yang committed
957
958
959
	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
	for _, layer := range layers {
		if err := verifyBlob(layer.Digest); err != nil {
960
961
962
963
964
965
966
967
968
969
970
			if errors.Is(err, errDigestMismatch) {
				// something went wrong, delete the blob
				fp, err := GetBlobsPath(layer.Digest)
				if err != nil {
					return err
				}
				if err := os.Remove(fp); err != nil {
					// log this, but return the original error
					log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
				}
			}
Michael Yang's avatar
Michael Yang committed
971
972
973
974
			return err
		}
	}

975
	fn(api.ProgressResponse{Status: "writing manifest"})
976

977
	manifestJSON, err := json.Marshal(manifest)
978
979
980
981
	if err != nil {
		return err
	}

Patrick Devine's avatar
Patrick Devine committed
982
	fp, err := mp.GetManifestPath(true)
983
984
985
986
	if err != nil {
		return err
	}

Bruce MacDonald's avatar
Bruce MacDonald committed
987
	err = os.WriteFile(fp, manifestJSON, 0o644)
988
989
990
991
992
	if err != nil {
		log.Printf("couldn't write to %s", fp)
		return err
	}

993
	fn(api.ProgressResponse{Status: "success"})
994
995
996
997

	return nil
}

998
999
func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
	url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
1000
1001
1002
1003
	headers := map[string]string{
		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
	}

1004
	resp, err := makeRequest("GET", url, headers, nil, regOpts)
1005
1006
1007
1008
1009
1010
1011
1012
	if err != nil {
		log.Printf("couldn't get manifest: %v", err)
		return nil, err
	}
	defer resp.Body.Close()

	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
	if resp.StatusCode != http.StatusOK {
1013
		if resp.StatusCode == http.StatusNotFound {
Bruce MacDonald's avatar
Bruce MacDonald committed
1014
			return nil, fmt.Errorf("model not found")
1015
		}
1016
		body, _ := io.ReadAll(resp.Body)
1017
		return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
	}

	var m *ManifestV2
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}

	return m, err
}

1028
1029
1030
1031
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
	config.RootFS = RootFS{
		Type:    "layers",
		DiffIDs: layers,
1032
1033
1034
1035
1036
1037
1038
	}

	configJSON, err := json.Marshal(config)
	if err != nil {
		return nil, err
	}

1039
	digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
1040

Michael Yang's avatar
Michael Yang committed
1041
	layer := &LayerReader{
1042
1043
1044
1045
1046
		Layer: Layer{
			MediaType: "application/vnd.docker.container.image.v1+json",
			Digest:    digest,
			Size:      size,
		},
1047
		Reader: bytes.NewBuffer(configJSON),
1048
1049
1050
1051
1052
	}
	return layer, nil
}

// GetSHA256Digest consumes r and returns its SHA-256 digest formatted as
// "sha256:<lowercase hex>", together with the number of bytes read.
//
// A read error from r is treated as fatal: it is logged and the process exits.
func GetSHA256Digest(r io.Reader) (string, int) {
	hasher := sha256.New()

	size, err := io.Copy(hasher, r)
	if err != nil {
		log.Fatal(err)
	}

	sum := hasher.Sum(nil)
	return "sha256:" + fmt.Sprintf("%x", sum), int(size)
}

1063
1064
// startUpload opens a blob upload session in mp's repository.
//
// It POSTs to <registry>/v2/<namespace/repository>/blobs/uploads/ and, when
// the registry answers 202 Accepted, returns the session URL from the
// Location header. Any other status, or a missing Location header, is
// returned as an error.
func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
	endpoint := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())

	resp, err := makeRequest("POST", endpoint, nil, nil, regOpts)
	if err != nil {
		log.Printf("couldn't start upload: %v", err)
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		body, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
	}

	// The Location header carries the upload session URL.
	if location := resp.Header.Get("Location"); location != "" {
		return location, nil
	}
	return "", fmt.Errorf("location header is missing in response")
}

// Function to check if a blob already exists in the Docker registry
1089
1090
func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
1091

1092
	resp, err := makeRequest("HEAD", url, nil, nil, regOpts)
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
	if err != nil {
		log.Printf("couldn't check for blob: %v", err)
		return false, err
	}
	defer resp.Body.Close()

	// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
	return resp.StatusCode == http.StatusOK, nil
}

Michael Yang's avatar
Michael Yang committed
1103
func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
1104
1105
1106
1107
	// TODO allow resumability
	// TODO allow canceling uploads via DELETE
	// TODO allow cross repo blob mount

Patrick Devine's avatar
Patrick Devine committed
1108
	fp, err := GetBlobsPath(layer.Digest)
1109
1110
1111
1112
	if err != nil {
		return err
	}

1113
1114
1115
1116
1117
	f, err := os.Open(fp)
	if err != nil {
		return err
	}

1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
	totalUploaded := 0

	r, w := io.Pipe()
	defer r.Close()

	go func() {
		defer w.Close()
		for {
			n, err := io.CopyN(w, f, 1024*1024)
			if err != nil && !errors.Is(err, io.EOF) {
				fn(api.ProgressResponse{
					Status:    fmt.Sprintf("error copying pipe: %v", err),
					Digest:    layer.Digest,
					Total:     layer.Size,
					Completed: totalUploaded,
				})
				return
			}
1136

1137
			totalUploaded += int(n)
1138
1139

			fn(api.ProgressResponse{
1140
				Status:    fmt.Sprintf("uploading %s", layer.Digest),
1141
				Digest:    layer.Digest,
1142
1143
				Total:     layer.Size,
				Completed: totalUploaded,
1144
			})
1145
1146
1147
1148

			if totalUploaded >= layer.Size {
				return
			}
1149
		}
1150
	}()
1151

1152
	url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
1153

1154
1155
1156
1157
	headers := make(map[string]string)
	headers["Content-Type"] = "application/octet-stream"
	headers["Content-Range"] = fmt.Sprintf("0-%d", layer.Size-1)
	headers["Content-Length"] = strconv.Itoa(int(layer.Size))
1158

1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
	// finish the upload
	resp, err := makeRequest("PUT", url, headers, r, regOpts)
	if err != nil {
		log.Printf("couldn't finish upload: %v", err)
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
1170
	}
1171
1172
1173
	return nil
}

1174
1175
1176
1177
1178
1179
1180
1181
1182
// makeRequest sends a single HTTP request to a registry and returns the raw
// response; the caller must close its body.
//
// If url has no "http://" or "https://" scheme, "https://" is prepended, or
// "http://" when regOpts.Insecure is set. Each entry in headers is applied
// with Header.Set. Basic auth is attached only when both Username and
// Password are non-empty. A nil regOpts is treated as the zero value
// (secure, anonymous). Redirects are followed up to 10 hops, and each hop is
// logged.
func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
	if regOpts == nil {
		regOpts = &RegistryOptions{}
	}

	// Match the full scheme so hosts that merely start with "http"
	// (e.g. "httpbin.org/...") still get a scheme prepended.
	if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
		if regOpts.Insecure {
			url = "http://" + url
		} else {
			url = "https://" + url
		}
	}

	req, err := http.NewRequest(method, url, body)
	if err != nil {
		return nil, err
	}

	for k, v := range headers {
		req.Header.Set(k, v)
	}

	// TODO: better auth
	if regOpts.Username != "" && regOpts.Password != "" {
		req.SetBasicAuth(regOpts.Username, regOpts.Password)
	}

	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 10 {
				return fmt.Errorf("too many redirects")
			}
			log.Printf("redirected to: %s", req.URL)
			return nil
		},
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	return resp, nil
}
Michael Yang's avatar
Michael Yang committed
1213

1214
1215
// errDigestMismatch is wrapped by verifyBlob when a stored blob's contents do
// not hash to the digest it is stored under; match it with errors.Is.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

Michael Yang's avatar
Michael Yang committed
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
func verifyBlob(digest string) error {
	fp, err := GetBlobsPath(digest)
	if err != nil {
		return err
	}

	f, err := os.Open(fp)
	if err != nil {
		return err
	}
	defer f.Close()

	fileDigest, _ := GetSHA256Digest(f)
	if digest != fileDigest {
1230
		return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
Michael Yang's avatar
Michael Yang committed
1231
1232
1233
1234
	}

	return nil
}